#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
import argparse
import html
import json
import logging
import numbers
import re
import traceback
from collections.abc import Callable
from dataclasses import dataclass

import networkx as nx

from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
# Prompt used to merge several descriptions of the same entity (or entity
# pair) into one coherent summary.  `{entity_name}` and `{description_list}`
# are filled in by perform_variable_replacements().
SUMMARIZE_PROMPT = """
You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
Make sure it is written in third person, and include the entity names so we have the full context.

#######
-Data-
Entities: {entity_name}
Description List: {description_list}
#######
Output:
"""
# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000
# Max token count for LLM answers
DEFAULT_MAX_SUMMARY_LENGTH = 128
class SummarizationResult: | |
"""Unipartite graph extraction result class definition.""" | |
items: str | tuple[str, str] | |
description: str | |
class SummarizeExtractor:
    """Unipartite graph extractor class definition.

    Merges multiple free-text descriptions of a single entity (or pair of
    entities) into one coherent summary via an LLM, chunking the input when
    the descriptions exceed the model's token budget.
    """

    _llm: CompletionLLM
    _entity_name_key: str
    _input_descriptions_key: str
    _summarization_prompt: str
    _on_error: ErrorHandlerFn
    _max_summary_length: int
    _max_input_tokens: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        entity_name_key: str | None = None,
        input_descriptions_key: str | None = None,
        summarization_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        max_summary_length: int | None = None,
        max_input_tokens: int | None = None,
    ):
        """Init method definition.

        Every parameter other than ``llm_invoker`` is optional and falls back
        to a module-level default when None (or falsy).
        """
        # TODO: streamline construction
        self._llm = llm_invoker
        self._entity_name_key = entity_name_key or "entity_name"
        self._input_descriptions_key = input_descriptions_key or "description_list"
        self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._max_summary_length = max_summary_length or DEFAULT_MAX_SUMMARY_LENGTH
        self._max_input_tokens = max_input_tokens or DEFAULT_MAX_INPUT_TOKENS

    def __call__(
        self,
        items: str | tuple[str, str],
        descriptions: list[str],
    ) -> SummarizationResult:
        """Summarize ``descriptions`` for ``items``.

        Zero descriptions yield an empty summary, a single description is
        returned verbatim, and multiple descriptions are merged by the LLM.
        """
        # NOTE(review): the original if/if/else chain let the empty case fall
        # through into the else branch; the guard chain below makes each case
        # exclusive.
        if not descriptions:
            result = ""
        elif len(descriptions) == 1:
            # A lone description needs no LLM call.
            result = descriptions[0]
        else:
            result = self._summarize_descriptions(items, descriptions)
        return SummarizationResult(
            items=items,
            description=result or "",
        )

    def _summarize_descriptions(
        self, items: str | tuple[str, str], descriptions: list[str]
    ) -> str:
        """Summarize descriptions into a single description.

        Descriptions are accumulated until the token budget is exhausted,
        then collapsed into a partial summary that seeds the next round.
        """
        # Normalize item order so prompts are deterministic for list input.
        sorted_items = sorted(items) if isinstance(items, list) else items

        # Safety check, should always be a list
        if not isinstance(descriptions, list):
            descriptions = [descriptions]

        # Tokens left for description text once the prompt template itself
        # is accounted for.
        usable_tokens = self._max_input_tokens - num_tokens_from_string(
            self._summarization_prompt
        )
        descriptions_collected = []
        result = ""

        # Iterate over descriptions, adding all until the max input tokens
        # is reached.
        for i, description in enumerate(descriptions):
            usable_tokens -= num_tokens_from_string(description)
            descriptions_collected.append(description)

            # If buffer is full, or all descriptions have been added, summarize
            if (usable_tokens < 0 and len(descriptions_collected) > 1) or (
                i == len(descriptions) - 1
            ):
                # Calculate result (final or partial).  The LLM call is
                # synchronous — this is not a coroutine, so no `await`.
                result = self._summarize_descriptions_with_llm(
                    sorted_items, descriptions_collected
                )

                # If we go for another loop, reset values to new
                if i != len(descriptions) - 1:
                    # Carry the partial summary forward as the sole collected
                    # description and rebuild the token budget around it.
                    descriptions_collected = [result]
                    usable_tokens = (
                        self._max_input_tokens
                        - num_tokens_from_string(self._summarization_prompt)
                        - num_tokens_from_string(result)
                    )

        return result

    def _summarize_descriptions_with_llm(
        self, items: str | tuple[str, str] | list[str], descriptions: list[str]
    ):
        """Summarize descriptions using the LLM."""
        variables = {
            self._entity_name_key: json.dumps(items),
            # Sort for a deterministic prompt regardless of arrival order.
            self._input_descriptions_key: json.dumps(sorted(descriptions)),
        }
        text = perform_variable_replacements(self._summarization_prompt, variables=variables)
        return self._llm.chat("", [{"role": "user", "content": text}])