Spaces:
Running
Running
File size: 9,526 Bytes
5e9cd1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
from typing import List, Optional
from langchain.schema.language_model import BaseLanguageModel
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import (logger)
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
import sys
import asyncio
class SummaryAdapter:
_OVERLAP_SIZE: int
token_max: int
_separator: str = "\n\n"
chain: MapReduceDocumentsChain
def __init__(self, overlap_size: int, token_max: int,
chain: MapReduceDocumentsChain):
self._OVERLAP_SIZE = overlap_size
self.chain = chain
self.token_max = token_max
@classmethod
def form_summary(cls,
llm: BaseLanguageModel,
reduce_llm: BaseLanguageModel,
overlap_size: int,
token_max: int = 1300):
"""
获取实例
:param reduce_llm: 用于合并摘要的llm
:param llm: 用于生成摘要的llm
:param overlap_size: 重叠部分大小
:param token_max: 最大的chunk数量,每个chunk长度小于token_max长度,第一次生成摘要时,大于token_max长度的摘要会报错
:return:
"""
# This controls how each document will be formatted. Specifically,
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
# The prompt here should take as an input variable the
# `document_variable_name`
prompt_template = (
"根据文本执行任务。以下任务信息"
"{task_briefing}"
"文本内容如下: "
"\r\n"
"{context}"
)
prompt = PromptTemplate(
template=prompt_template,
input_variables=["task_briefing", "context"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
# We now define how to combine these summaries
reduce_prompt = PromptTemplate.from_template(
"Combine these summaries: {context}"
)
reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)
document_variable_name = "context"
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
reduce_documents_chain = ReduceDocumentsChain(
token_max=token_max,
combine_documents_chain=combine_documents_chain,
)
chain = MapReduceDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
reduce_documents_chain=reduce_documents_chain,
# 返回中间步骤
return_intermediate_steps=True
)
return cls(overlap_size=overlap_size,
chain=chain,
token_max=token_max)
def summarize(self,
file_description: str,
docs: List[DocumentWithVSId] = []
) -> List[Document]:
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 同步调用协程代码
return loop.run_until_complete(self.asummarize(file_description=file_description,
docs=docs))
async def asummarize(self,
file_description: str,
docs: List[DocumentWithVSId] = []) -> List[Document]:
logger.info("start summary")
"""
这个过程分成两个部分:
1. 对每个文档进行处理,得到每个文档的摘要
map_results = self.llm_chain.apply(
# FYI - this is parallelized and so it is fast.
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
callbacks=callbacks,
)
2. 对每个文档的摘要进行合并,得到最终的摘要,return_intermediate_steps=True,返回中间步骤
result, extra_return_dict = self.reduce_documents_chain.combine_docs(
result_docs, token_max=token_max, callbacks=callbacks, **kwargs
)
"""
summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
task_briefing="描述不同方法之间的接近度和相似性,"
"以帮助读者理解它们之间的关系。")
print(summary_combine)
print(summary_intermediate_steps)
# if len(summary_combine) == 0:
# # 为空重新生成,数量减半
# result_docs = [
# Document(page_content=question_result_key, metadata=docs[i].metadata)
# # This uses metadata from the docs, and the textual results from `results`
# for i, question_result_key in enumerate(
# summary_intermediate_steps["intermediate_steps"][
# :len(summary_intermediate_steps["intermediate_steps"]) // 2
# ])
# ]
# summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
# result_docs, token_max=self.token_max
# )
logger.info("end summary")
doc_ids = ",".join([doc.id for doc in docs])
_metadata = {
"file_description": file_description,
"summary_intermediate_steps": summary_intermediate_steps,
"doc_ids": doc_ids
}
summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)
return [summary_combine_doc]
def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
"""
# 将文档中page_content句子叠加的部分去掉
:param docs:
:param separator:
:return:
"""
merge_docs = []
pre_doc = None
for doc in docs:
# 第一个文档直接添加
if len(merge_docs) == 0:
pre_doc = doc.page_content
merge_docs.append(doc.page_content)
continue
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
# 迭代递减pre_doc的长度,每次迭代删除前面的字符,
# 查询重叠部分,直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2len(separator)
for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
# 每次迭代删除前面的字符
pre_doc = pre_doc[1:]
if doc.page_content[:len(pre_doc)] == pre_doc:
# 删除下一个开头重叠的部分
merge_docs.append(doc.page_content[len(pre_doc):])
break
pre_doc = doc.page_content
return merge_docs
def _join_docs(self, docs: List[str]) -> Optional[str]:
text = self._separator.join(docs)
text = text.strip()
if text == "":
return None
else:
return text
if __name__ == '__main__':
docs = [
'梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
'梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
'使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
'值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
]
_OVERLAP_SIZE = 1
separator: str = "\n\n"
merge_docs = []
# 将文档中page_content句子叠加的部分去掉,
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
pre_doc = None
for doc in docs:
# 第一个文档直接添加
if len(merge_docs) == 0:
pre_doc = doc
merge_docs.append(doc)
continue
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
# 迭代递减pre_doc的长度,每次迭代删除前面的字符,
# 查询重叠部分,直到pre_doc的长度小于 _OVERLAP_SIZE-2len(separator)
for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
# 每次迭代删除前面的字符
pre_doc = pre_doc[1:]
if doc[:len(pre_doc)] == pre_doc:
# 删除下一个开头重叠的部分
page_content = doc[len(pre_doc):]
merge_docs.append(page_content)
pre_doc = doc
break
# 将merge_docs中的句子合并成一个文档
text = separator.join(merge_docs)
text = text.strip()
print(text)
|