Spaces (Sleeping)
cbys4 committed · Commit 2785a12 (unverified) · 0 Parent(s)
Add files via upload
Browse files
- PathRAG/PathRAG.py +562 -0
- PathRAG/__init__.py +3 -0
- PathRAG/__pycache__/PathRAG.cpython-39.pyc +0 -0
- PathRAG/__pycache__/__init__.cpython-39.pyc +0 -0
- PathRAG/__pycache__/base.cpython-39.pyc +0 -0
- PathRAG/__pycache__/llm.cpython-39.pyc +0 -0
- PathRAG/__pycache__/operate.cpython-39.pyc +0 -0
- PathRAG/__pycache__/prompt.cpython-39.pyc +0 -0
- PathRAG/__pycache__/storage.cpython-39.pyc +0 -0
- PathRAG/__pycache__/utils.cpython-39.pyc +0 -0
- PathRAG/base.py +135 -0
- PathRAG/llm.py +1036 -0
- PathRAG/operate.py +1239 -0
- PathRAG/prompt.py +286 -0
- PathRAG/storage.py +341 -0
- PathRAG/utils.py +527 -0
- PathRAG/v1_test.py +49 -0
- requirements.txt +28 -0
PathRAG/PathRAG.py
ADDED
@@ -0,0 +1,562 @@
import asyncio
import os
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Type, cast

from .llm import (
    gpt_4o_mini_complete,
    openai_embedding,
)
from .operate import (
    chunking_by_token_size,
    extract_entities,
    kg_query,
)
from .utils import (
    EmbeddingFunc,
    compute_mdhash_id,
    limit_async_func_call,
    convert_response_to_json,
    logger,
    set_logger,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    StorageNameSpace,
    QueryParam,
)
from .storage import (
    JsonKVStorage,
    NanoVectorDBStorage,
    NetworkXStorage,
)


def lazy_external_import(module_name: str, class_name: str):
    """Lazily import a class from an external module based on the package of the caller."""
    import inspect

    caller_frame = inspect.currentframe().f_back
    module = inspect.getmodule(caller_frame)
    package = module.__package__ if module else None

    def import_class(*args, **kwargs):
        import importlib

        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class


Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """
    Ensure that there is always an event loop available.

    This function tries to get the current event loop. If the current event loop is closed or does not exist,
    it creates a new event loop and sets it as the current event loop.

    Returns:
        asyncio.AbstractEventLoop: The current or newly created event loop.
    """
    try:
        current_loop = asyncio.get_event_loop()
        if current_loop.is_closed():
            raise RuntimeError("Event loop is closed.")
        return current_loop
    except RuntimeError:
        logger.info("Creating a new event loop in main thread.")
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        return new_loop


@dataclass
class PathRAG:
    working_dir: str = field(
        default_factory=lambda: f"./PathRAG_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
    )

    embedding_cache_config: dict = field(
        default_factory=lambda: {
            "enabled": False,
            "similarity_threshold": 0.95,
            "use_llm_check": False,
        }
    )
    kv_storage: str = field(default="JsonKVStorage")
    vector_storage: str = field(default="NanoVectorDBStorage")
    graph_storage: str = field(default="NetworkXStorage")

    current_log_level = logger.level
    log_level: str = field(default=current_log_level)

    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    llm_model_func: callable = gpt_4o_mini_complete
    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16
    llm_model_kwargs: dict = field(default_factory=dict)

    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)

    enable_llm_cache: bool = True

    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        log_file = os.path.join("PathRAG.log")
        set_logger(log_file)
        logger.setLevel(self.log_level)

        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
            self._get_storage_class()[self.kv_storage]
        )
        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
            self.vector_storage
        ]
        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
            self.graph_storage
        ]

        if not os.path.exists(self.working_dir):
            logger.info(f"Creating working directory {self.working_dir}")
            os.makedirs(self.working_dir)

        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache",
                global_config=asdict(self),
                embedding_func=None,
            )
            if self.enable_llm_cache
            else None
        )
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )

        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(
                self.llm_model_func,
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    global_config=asdict(self),
                ),
                **self.llm_model_kwargs,
            )
        )

    def _get_storage_class(self) -> Type[BaseGraphStorage]:
        return {
            "JsonKVStorage": JsonKVStorage,
            "OracleKVStorage": OracleKVStorage,
            "MongoKVStorage": MongoKVStorage,
            "TiDBKVStorage": TiDBKVStorage,
            "NanoVectorDBStorage": NanoVectorDBStorage,
            "OracleVectorDBStorage": OracleVectorDBStorage,
            "MilvusVectorDBStorge": MilvusVectorDBStorge,
            "ChromaVectorDBStorage": ChromaVectorDBStorage,
            "TiDBVectorDBStorage": TiDBVectorDBStorage,
            "NetworkXStorage": NetworkXStorage,
            "Neo4JStorage": Neo4JStorage,
            "OracleGraphStorage": OracleGraphStorage,
            "AGEStorage": AGEStorage,
        }

    def insert(self, string_or_strings):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        update_storage = False
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not len(new_docs):
                logger.warning("All docs are already in the storage")
                return
            update_storage = True
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            inserting_chunks = {}
            for doc_key, doc in tqdm_async(
                new_docs.items(), desc="Chunking documents", unit="doc"
            ):
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not len(inserting_chunks):
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            logger.info("[Entity Extraction]...")
            maybe_new_kg = await extract_entities(
                inserting_chunks,
                knowledge_graph_inst=self.chunk_entity_relation_graph,
                entity_vdb=self.entities_vdb,
                relationships_vdb=self.relationships_vdb,
                global_config=asdict(self),
            )
            if maybe_new_kg is None:
                logger.warning("No new entities and relationships found")
                return
            self.chunk_entity_relation_graph = maybe_new_kg

            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            if update_storage:
                await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def insert_custom_kg(self, custom_kg: dict):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))

    async def ainsert_custom_kg(self, custom_kg: dict):
        update_storage = False
        try:
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

                chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            all_entities_data = []
            for entity_data in custom_kg.get("entities", []):
                entity_name = f'"{entity_data["entity_name"].upper()}"'
                entity_type = entity_data.get("entity_type", "UNKNOWN")
                description = entity_data.get("description", "No description provided")

                source_chunk_id = entity_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                node_data = {
                    "entity_type": entity_type,
                    "description": description,
                    "source_id": source_id,
                }

                await self.chunk_entity_relation_graph.upsert_node(
                    entity_name, node_data=node_data
                )
                node_data["entity_name"] = entity_name
                all_entities_data.append(node_data)
                update_storage = True

            all_relationships_data = []
            for relationship_data in custom_kg.get("relationships", []):
                src_id = f'"{relationship_data["src_id"].upper()}"'
                tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
                description = relationship_data["description"]
                keywords = relationship_data["keywords"]
                weight = relationship_data.get("weight", 1.0)

                source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
                source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")

                if source_id == "UNKNOWN":
                    logger.warning(
                        f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
                    )

                for need_insert_id in [src_id, tgt_id]:
                    if not (
                        await self.chunk_entity_relation_graph.has_node(need_insert_id)
                    ):
                        await self.chunk_entity_relation_graph.upsert_node(
                            need_insert_id,
                            node_data={
                                "source_id": source_id,
                                "description": "UNKNOWN",
                                "entity_type": "UNKNOWN",
                            },
                        )

                await self.chunk_entity_relation_graph.upsert_edge(
                    src_id,
                    tgt_id,
                    edge_data={
                        "weight": weight,
                        "description": description,
                        "keywords": keywords,
                        "source_id": source_id,
                    },
                )
                edge_data = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)

            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)
        finally:
            if update_storage:
                await self._insert_done()

    def query(self, query: str, param: QueryParam = QueryParam()):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: QueryParam = QueryParam()):
        if param.mode in ["hybrid"]:
            response = await kg_query(
                query,
                self.chunk_entity_relation_graph,
                self.entities_vdb,
                self.relationships_vdb,
                self.text_chunks,
                param,
                asdict(self),
                hashing_kv=self.llm_response_cache
                if self.llm_response_cache
                and hasattr(self.llm_response_cache, "global_config")
                else self.key_string_value_json_storage_cls(
                    global_config=asdict(self),
                ),
            )
            print("response all ready")
        else:
            raise ValueError(f"Unknown mode {param.mode}")
        await self._query_done()
        return response

    async def _query_done(self):
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def delete_by_entity(self, entity_name: str):
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.adelete_by_entity(entity_name))

    async def adelete_by_entity(self, entity_name: str):
        entity_name = f'"{entity_name.upper()}"'

        try:
            await self.entities_vdb.delete_entity(entity_name)
            await self.relationships_vdb.delete_relation(entity_name)
            await self.chunk_entity_relation_graph.delete_node(entity_name)

            logger.info(
                f"Entity '{entity_name}' and its relationships have been deleted."
            )
            await self._delete_by_entity_done()
        except Exception as e:
            logger.error(f"Error while deleting entity '{entity_name}': {e}")

    async def _delete_by_entity_done(self):
        tasks = []
        for storage_inst in [
            self.entities_vdb,
            self.relationships_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
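For orientation, the class above is normally driven through its synchronous wrappers. The following is a minimal usage sketch, not code from this commit: the working directory, input file, and question are illustrative, and it assumes an OpenAI API key is available in the environment for the default gpt_4o_mini_complete and openai_embedding functions.

    import os
    from PathRAG import PathRAG, QueryParam

    WORKING_DIR = "./PathRAG_cache_demo"   # illustrative path
    os.makedirs(WORKING_DIR, exist_ok=True)

    # Defaults: JsonKVStorage + NanoVectorDBStorage + NetworkXStorage,
    # gpt-4o-mini completions and OpenAI embeddings.
    rag = PathRAG(working_dir=WORKING_DIR)

    # insert() chunks the text, extracts entities and relations, and
    # populates the graph and vector stores.
    with open("demo.txt", encoding="utf-8") as f:   # illustrative input file
        rag.insert(f.read())

    # aquery() only handles mode="hybrid"; any other mode raises ValueError.
    answer = rag.query("What are the main themes?", param=QueryParam(mode="hybrid"))
    print(answer)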
PathRAG/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .PathRAG import PathRAG as PathRAG, QueryParam as QueryParam
PathRAG/__pycache__/PathRAG.cpython-39.pyc
ADDED
Binary file (13.5 kB)

PathRAG/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (188 Bytes)

PathRAG/__pycache__/base.cpython-39.pyc
ADDED
Binary file (6.31 kB)

PathRAG/__pycache__/llm.cpython-39.pyc
ADDED
Binary file (26.9 kB)

PathRAG/__pycache__/operate.cpython-39.pyc
ADDED
Binary file (29.5 kB)

PathRAG/__pycache__/prompt.cpython-39.pyc
ADDED
Binary file (18.7 kB)

PathRAG/__pycache__/storage.cpython-39.pyc
ADDED
Binary file (15.4 kB)

PathRAG/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (15.1 kB)
PathRAG/base.py
ADDED
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import TypedDict, Union, Literal, Generic, TypeVar

import numpy as np

from .utils import EmbeddingFunc

TextChunkSchema = TypedDict(
    "TextChunkSchema",
    {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
)

T = TypeVar("T")


@dataclass
class QueryParam:
    mode: Literal["hybrid"] = "global"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    top_k: int = 40
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 3000
    max_token_for_local_context: int = 5000


@dataclass
class StorageNameSpace:
    namespace: str
    global_config: dict

    async def index_done_callback(self):
        pass

    async def query_done_callback(self):
        pass


@dataclass
class BaseVectorStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        raise NotImplementedError


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    embedding_func: EmbeddingFunc

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError

    async def drop(self):
        raise NotImplementedError


@dataclass
class BaseGraphStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc = None

    async def has_node(self, node_id: str) -> bool:
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        raise NotImplementedError

    async def get_pagerank(self, node_id: str) -> float:
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def get_node_in_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def get_node_out_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        raise NotImplementedError

    async def delete_node(self, node_id: str):
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        raise NotImplementedError("Node embedding is not used in PathRag.")
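These abstract classes are the storage contract that PathRAG.py resolves by name in _get_storage_class(). As a rough sketch of that contract (not a backend included in this commit), an in-memory key-value store might look like the following; note that ainsert() relies on filter_keys() returning the keys that are not yet stored.

    from dataclasses import dataclass, field
    from typing import Union

    from PathRAG.base import BaseKVStorage   # import path assumes this package layout

    @dataclass
    class InMemoryKVStorage(BaseKVStorage):
        # Illustrative only; the commit ships JsonKVStorage, MongoKVStorage, TiDBKVStorage, etc.
        _data: dict = field(default_factory=dict)

        async def all_keys(self) -> list[str]:
            return list(self._data.keys())

        async def get_by_id(self, id: str) -> Union[dict, None]:
            return self._data.get(id)

        async def get_by_ids(self, ids, fields=None):
            return [self._data.get(i) for i in ids]

        async def filter_keys(self, data: list[str]) -> set[str]:
            # Return the keys that are NOT already stored (i.e. the new ones).
            return {k for k in data if k not in self._data}

        async def upsert(self, data: dict):
            self._data.update(data)

        async def drop(self):
            self._data.clear()

    # Constructed like the built-in backends:
    # kv = InMemoryKVStorage(namespace="full_docs", global_config={}, embedding_func=None)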
PathRAG/llm.py
ADDED
@@ -0,0 +1,1036 @@
import base64
import copy
import json
import os
import re
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union, Optional
import aioboto3
import aiohttp
import numpy as np
import ollama
import torch
import time
from openai import (
    AsyncOpenAI,
    APIConnectionError,
    RateLimitError,
    Timeout,
    AsyncAzureOpenAI,
)
from pydantic import BaseModel, Field
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM

from .utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)

import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    base_url="https://api.openai.com/v1",
    api_key="",
    **kwargs,
) -> str:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    time.sleep(2)
    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")
    logger.debug("Full context:")
    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
        )
    else:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    if hasattr(response, "__aiter__"):

        async def inner():
            async for chunk in response:
                content = chunk.choices[0].delta.content
                if content is None:
                    continue
                if r"\u" in content:
                    content = safe_unicode_decode(content.encode("utf-8"))
                yield content

        return inner()
    else:
        content = response.choices[0].message.content
        if r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))
        return content
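The *_complete wrappers further down in this file (gpt_4o_complete, gpt_4o_mini_complete, nvidia_openai_complete, and so on) all delegate to openai_complete_if_cache above. A custom OpenAI-compatible model can be wired into PathRAG the same way; this is a hedged sketch in which the model id, endpoint, and key are placeholders, not values from the repository.

    from PathRAG import PathRAG

    async def my_model_complete(
        prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
    ) -> str:
        # keyword_extraction and hashing_kv (added by PathRAG via functools.partial)
        # are popped inside openai_complete_if_cache before the API call.
        return await openai_complete_if_cache(
            "my-model-id",                        # placeholder model name
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            base_url="https://example.com/v1",    # placeholder OpenAI-compatible endpoint
            api_key="YOUR_API_KEY",               # placeholder; prefer an environment variable
            **kwargs,
        )

    rag = PathRAG(llm_model_func=my_model_complete)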
+
@retry(
|
| 110 |
+
stop=stop_after_attempt(3),
|
| 111 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 112 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
| 113 |
+
)
|
| 114 |
+
async def azure_openai_complete_if_cache(
|
| 115 |
+
model,
|
| 116 |
+
prompt,
|
| 117 |
+
system_prompt=None,
|
| 118 |
+
history_messages=[],
|
| 119 |
+
base_url=None,
|
| 120 |
+
api_key=None,
|
| 121 |
+
api_version=None,
|
| 122 |
+
**kwargs,
|
| 123 |
+
):
|
| 124 |
+
if api_key:
|
| 125 |
+
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
| 126 |
+
if base_url:
|
| 127 |
+
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
| 128 |
+
if api_version:
|
| 129 |
+
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
|
| 130 |
+
|
| 131 |
+
openai_async_client = AsyncAzureOpenAI(
|
| 132 |
+
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
| 133 |
+
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
| 134 |
+
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
| 135 |
+
)
|
| 136 |
+
kwargs.pop("hashing_kv", None)
|
| 137 |
+
messages = []
|
| 138 |
+
if system_prompt:
|
| 139 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 140 |
+
messages.extend(history_messages)
|
| 141 |
+
if prompt is not None:
|
| 142 |
+
messages.append({"role": "user", "content": prompt})
|
| 143 |
+
|
| 144 |
+
response = await openai_async_client.chat.completions.create(
|
| 145 |
+
model=model, messages=messages, **kwargs
|
| 146 |
+
)
|
| 147 |
+
content = response.choices[0].message.content
|
| 148 |
+
|
| 149 |
+
return content
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@retry(
|
| 156 |
+
stop=stop_after_attempt(5),
|
| 157 |
+
wait=wait_exponential(multiplier=1, max=60),
|
| 158 |
+
retry=retry_if_exception_type((BedrockError)),
|
| 159 |
+
)
|
| 160 |
+
async def bedrock_complete_if_cache(
|
| 161 |
+
model,
|
| 162 |
+
prompt,
|
| 163 |
+
system_prompt=None,
|
| 164 |
+
history_messages=[],
|
| 165 |
+
aws_access_key_id=None,
|
| 166 |
+
aws_secret_access_key=None,
|
| 167 |
+
aws_session_token=None,
|
| 168 |
+
**kwargs,
|
| 169 |
+
) -> str:
|
| 170 |
+
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
| 171 |
+
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
| 172 |
+
)
|
| 173 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
| 174 |
+
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
| 175 |
+
)
|
| 176 |
+
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
| 177 |
+
"AWS_SESSION_TOKEN", aws_session_token
|
| 178 |
+
)
|
| 179 |
+
kwargs.pop("hashing_kv", None)
|
| 180 |
+
|
| 181 |
+
messages = []
|
| 182 |
+
for history_message in history_messages:
|
| 183 |
+
message = copy.copy(history_message)
|
| 184 |
+
message["content"] = [{"text": message["content"]}]
|
| 185 |
+
messages.append(message)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
messages.append({"role": "user", "content": [{"text": prompt}]})
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
args = {"modelId": model, "messages": messages}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if system_prompt:
|
| 195 |
+
args["system"] = [{"text": system_prompt}]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
inference_params_map = {
|
| 199 |
+
"max_tokens": "maxTokens",
|
| 200 |
+
"top_p": "topP",
|
| 201 |
+
"stop_sequences": "stopSequences",
|
| 202 |
+
}
|
| 203 |
+
if inference_params := list(
|
| 204 |
+
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
| 205 |
+
):
|
| 206 |
+
args["inferenceConfig"] = {}
|
| 207 |
+
for param in inference_params:
|
| 208 |
+
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
| 209 |
+
kwargs.pop(param)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
session = aioboto3.Session()
|
| 214 |
+
async with session.client("bedrock-runtime") as bedrock_async_client:
|
| 215 |
+
try:
|
| 216 |
+
response = await bedrock_async_client.converse(**args, **kwargs)
|
| 217 |
+
except Exception as e:
|
| 218 |
+
raise BedrockError(e)
|
| 219 |
+
|
| 220 |
+
return response["output"]["message"]["content"][0]["text"]
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@lru_cache(maxsize=1)
|
| 224 |
+
def initialize_hf_model(model_name):
|
| 225 |
+
hf_tokenizer = AutoTokenizer.from_pretrained(
|
| 226 |
+
model_name, device_map="auto", trust_remote_code=True
|
| 227 |
+
)
|
| 228 |
+
hf_model = AutoModelForCausalLM.from_pretrained(
|
| 229 |
+
model_name, device_map="auto", trust_remote_code=True
|
| 230 |
+
)
|
| 231 |
+
if hf_tokenizer.pad_token is None:
|
| 232 |
+
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
| 233 |
+
|
| 234 |
+
return hf_model, hf_tokenizer
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@retry(
|
| 238 |
+
stop=stop_after_attempt(3),
|
| 239 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 240 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
| 241 |
+
)
|
| 242 |
+
async def hf_model_if_cache(
|
| 243 |
+
model,
|
| 244 |
+
prompt,
|
| 245 |
+
system_prompt=None,
|
| 246 |
+
history_messages=[],
|
| 247 |
+
**kwargs,
|
| 248 |
+
) -> str:
|
| 249 |
+
model_name = model
|
| 250 |
+
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
| 251 |
+
messages = []
|
| 252 |
+
if system_prompt:
|
| 253 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 254 |
+
messages.extend(history_messages)
|
| 255 |
+
messages.append({"role": "user", "content": prompt})
|
| 256 |
+
kwargs.pop("hashing_kv", None)
|
| 257 |
+
input_prompt = ""
|
| 258 |
+
try:
|
| 259 |
+
input_prompt = hf_tokenizer.apply_chat_template(
|
| 260 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 261 |
+
)
|
| 262 |
+
except Exception:
|
| 263 |
+
try:
|
| 264 |
+
ori_message = copy.deepcopy(messages)
|
| 265 |
+
if messages[0]["role"] == "system":
|
| 266 |
+
messages[1]["content"] = (
|
| 267 |
+
"<system>"
|
| 268 |
+
+ messages[0]["content"]
|
| 269 |
+
+ "</system>\n"
|
| 270 |
+
+ messages[1]["content"]
|
| 271 |
+
)
|
| 272 |
+
messages = messages[1:]
|
| 273 |
+
input_prompt = hf_tokenizer.apply_chat_template(
|
| 274 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 275 |
+
)
|
| 276 |
+
except Exception:
|
| 277 |
+
len_message = len(ori_message)
|
| 278 |
+
for msgid in range(len_message):
|
| 279 |
+
input_prompt = (
|
| 280 |
+
input_prompt
|
| 281 |
+
+ "<"
|
| 282 |
+
+ ori_message[msgid]["role"]
|
| 283 |
+
+ ">"
|
| 284 |
+
+ ori_message[msgid]["content"]
|
| 285 |
+
+ "</"
|
| 286 |
+
+ ori_message[msgid]["role"]
|
| 287 |
+
+ ">\n"
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
input_ids = hf_tokenizer(
|
| 291 |
+
input_prompt, return_tensors="pt", padding=True, truncation=True
|
| 292 |
+
).to("cuda")
|
| 293 |
+
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
|
| 294 |
+
output = hf_model.generate(
|
| 295 |
+
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
|
| 296 |
+
)
|
| 297 |
+
response_text = hf_tokenizer.decode(
|
| 298 |
+
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
return response_text
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@retry(
|
| 305 |
+
stop=stop_after_attempt(3),
|
| 306 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 307 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
| 308 |
+
)
|
| 309 |
+
async def ollama_model_if_cache(
|
| 310 |
+
model,
|
| 311 |
+
prompt,
|
| 312 |
+
system_prompt=None,
|
| 313 |
+
history_messages=[],
|
| 314 |
+
**kwargs,
|
| 315 |
+
) -> Union[str, AsyncIterator[str]]:
|
| 316 |
+
stream = True if kwargs.get("stream") else False
|
| 317 |
+
kwargs.pop("max_tokens", None)
|
| 318 |
+
|
| 319 |
+
host = kwargs.pop("host", None)
|
| 320 |
+
timeout = kwargs.pop("timeout", None)
|
| 321 |
+
kwargs.pop("hashing_kv", None)
|
| 322 |
+
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
|
| 323 |
+
messages = []
|
| 324 |
+
if system_prompt:
|
| 325 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 326 |
+
messages.extend(history_messages)
|
| 327 |
+
messages.append({"role": "user", "content": prompt})
|
| 328 |
+
|
| 329 |
+
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
|
| 330 |
+
if stream:
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
async def inner():
|
| 334 |
+
async for chunk in response:
|
| 335 |
+
yield chunk["message"]["content"]
|
| 336 |
+
|
| 337 |
+
return inner()
|
| 338 |
+
else:
|
| 339 |
+
return response["message"]["content"]
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@lru_cache(maxsize=1)
|
| 343 |
+
def initialize_lmdeploy_pipeline(
|
| 344 |
+
model,
|
| 345 |
+
tp=1,
|
| 346 |
+
chat_template=None,
|
| 347 |
+
log_level="WARNING",
|
| 348 |
+
model_format="hf",
|
| 349 |
+
quant_policy=0,
|
| 350 |
+
):
|
| 351 |
+
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
| 352 |
+
|
| 353 |
+
lmdeploy_pipe = pipeline(
|
| 354 |
+
model_path=model,
|
| 355 |
+
backend_config=TurbomindEngineConfig(
|
| 356 |
+
tp=tp, model_format=model_format, quant_policy=quant_policy
|
| 357 |
+
),
|
| 358 |
+
chat_template_config=(
|
| 359 |
+
ChatTemplateConfig(model_name=chat_template) if chat_template else None
|
| 360 |
+
),
|
| 361 |
+
log_level="WARNING",
|
| 362 |
+
)
|
| 363 |
+
return lmdeploy_pipe
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
@retry(
|
| 367 |
+
stop=stop_after_attempt(3),
|
| 368 |
+
wait=wait_exponential(multiplier=1, min=4, max=10),
|
| 369 |
+
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
| 370 |
+
)
|
| 371 |
+
async def lmdeploy_model_if_cache(
|
| 372 |
+
model,
|
| 373 |
+
prompt,
|
| 374 |
+
system_prompt=None,
|
| 375 |
+
history_messages=[],
|
| 376 |
+
chat_template=None,
|
| 377 |
+
model_format="hf",
|
| 378 |
+
quant_policy=0,
|
| 379 |
+
**kwargs,
|
| 380 |
+
) -> str:
|
| 381 |
+
|
| 382 |
+
try:
|
| 383 |
+
import lmdeploy
|
| 384 |
+
from lmdeploy import version_info, GenerationConfig
|
| 385 |
+
except Exception:
|
| 386 |
+
raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
|
| 387 |
+
kwargs.pop("hashing_kv", None)
|
| 388 |
+
kwargs.pop("response_format", None)
|
| 389 |
+
max_new_tokens = kwargs.pop("max_tokens", 512)
|
| 390 |
+
tp = kwargs.pop("tp", 1)
|
| 391 |
+
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
|
| 392 |
+
do_preprocess = kwargs.pop("do_preprocess", True)
|
| 393 |
+
do_sample = kwargs.pop("do_sample", False)
|
| 394 |
+
gen_params = kwargs
|
| 395 |
+
|
| 396 |
+
version = version_info
|
| 397 |
+
if do_sample is not None and version < (0, 6, 0):
|
| 398 |
+
raise RuntimeError(
|
| 399 |
+
"`do_sample` parameter is not supported by lmdeploy until "
|
| 400 |
+
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
| 401 |
+
)
|
| 402 |
+
else:
|
| 403 |
+
do_sample = True
|
| 404 |
+
gen_params.update(do_sample=do_sample)
|
| 405 |
+
|
| 406 |
+
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
| 407 |
+
model=model,
|
| 408 |
+
tp=tp,
|
| 409 |
+
chat_template=chat_template,
|
| 410 |
+
model_format=model_format,
|
| 411 |
+
quant_policy=quant_policy,
|
| 412 |
+
log_level="WARNING",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
messages = []
|
| 416 |
+
if system_prompt:
|
| 417 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 418 |
+
|
| 419 |
+
messages.extend(history_messages)
|
| 420 |
+
messages.append({"role": "user", "content": prompt})
|
| 421 |
+
|
| 422 |
+
gen_config = GenerationConfig(
|
| 423 |
+
skip_special_tokens=skip_special_tokens,
|
| 424 |
+
max_new_tokens=max_new_tokens,
|
| 425 |
+
**gen_params,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
response = ""
|
| 429 |
+
async for res in lmdeploy_pipe.generate(
|
| 430 |
+
messages,
|
| 431 |
+
gen_config=gen_config,
|
| 432 |
+
do_preprocess=do_preprocess,
|
| 433 |
+
stream_response=False,
|
| 434 |
+
session_id=1,
|
| 435 |
+
):
|
| 436 |
+
response += res.response
|
| 437 |
+
return response
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class GPTKeywordExtractionFormat(BaseModel):
|
| 441 |
+
high_level_keywords: List[str]
|
| 442 |
+
low_level_keywords: List[str]
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
async def openai_complete(
|
| 446 |
+
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
| 447 |
+
) -> Union[str, AsyncIterator[str]]:
|
| 448 |
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
| 449 |
+
if keyword_extraction:
|
| 450 |
+
kwargs["response_format"] = "json"
|
| 451 |
+
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
| 452 |
+
return await openai_complete_if_cache(
|
| 453 |
+
model_name,
|
| 454 |
+
prompt,
|
| 455 |
+
system_prompt=system_prompt,
|
| 456 |
+
history_messages=history_messages,
|
| 457 |
+
**kwargs,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
async def gpt_4o_complete(
|
| 462 |
+
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
| 463 |
+
) -> str:
|
| 464 |
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
| 465 |
+
if keyword_extraction:
|
| 466 |
+
kwargs["response_format"] = GPTKeywordExtractionFormat
|
| 467 |
+
return await openai_complete_if_cache(
|
| 468 |
+
"gpt-4o",
|
| 469 |
+
prompt,
|
| 470 |
+
system_prompt=system_prompt,
|
| 471 |
+
history_messages=history_messages,
|
| 472 |
+
**kwargs,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
async def gpt_4o_mini_complete(
|
| 477 |
+
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
| 478 |
+
) -> str:
|
| 479 |
+
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
| 480 |
+
if keyword_extraction:
|
| 481 |
+
kwargs["response_format"] = GPTKeywordExtractionFormat
|
| 482 |
+
return await openai_complete_if_cache(
|
| 483 |
+
"gpt-4o-mini",
|
| 484 |
+
prompt,
|
| 485 |
+
system_prompt=system_prompt,
|
| 486 |
+
history_messages=history_messages,
|
| 487 |
+
**kwargs,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
async def nvidia_openai_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    result = await azure_openai_complete_if_cache(
        "conversation-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def bedrock_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    result = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await hf_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await ollama_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: List[Dict[str, str]] = [],
    **kwargs,
) -> str:
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        client = ZhipuAI()

    messages = []

    if not system_prompt:
        system_prompt = "You are a helpful assistant. Replace any sensitive words in the content with ***."

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    kwargs = {
        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
    }

    response = client.chat.completions.create(model=model, messages=messages, **kwargs)

    return response.choices[0].message.content


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
    keyword_extraction = kwargs.pop("keyword_extraction", None)

    if keyword_extraction:
        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
        Please analyze the content and extract two types of keywords:
        1. High-level keywords: Important concepts and main themes
        2. Low-level keywords: Specific details and supporting elements

        Return your response in this exact JSON format:
        {
            "high_level_keywords": ["keyword1", "keyword2"],
            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
        }

        Only return the JSON, no other text."""

        if system_prompt:
            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
        else:
            system_prompt = extraction_prompt

        try:
            response = await zhipu_complete_if_cache(
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                **kwargs,
            )

            try:
                data = json.loads(response)
                return GPTKeywordExtractionFormat(
                    high_level_keywords=data.get("high_level_keywords", []),
                    low_level_keywords=data.get("low_level_keywords", []),
                )
            except json.JSONDecodeError:
                match = re.search(r"\{[\s\S]*\}", response)
                if match:
                    try:
                        data = json.loads(match.group())
                        return GPTKeywordExtractionFormat(
                            high_level_keywords=data.get("high_level_keywords", []),
                            low_level_keywords=data.get("low_level_keywords", []),
                        )
                    except json.JSONDecodeError:
                        pass

                logger.warning(
                    f"Failed to parse keyword extraction response: {response}"
                )
                return GPTKeywordExtractionFormat(
                    high_level_keywords=[], low_level_keywords=[]
                )
        except Exception as e:
            logger.error(f"Error during keyword extraction: {str(e)}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )
    else:
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
    texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        client = ZhipuAI()

    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            response = client.embeddings.create(model=model, input=[text], **kwargs)
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")

    return np.array(embeddings)


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url="https://api.openai.com/v1",
    api_key="",
) -> np.ndarray:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])


async def fetch_data(url, headers, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            response_json = await response.json()
            data_list = response_json.get("data", [])
            return data_list


async def jina_embedding(
    texts: list[str],
    dimensions: int = 1024,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["JINA_API_KEY"] = api_key
    url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
    }
    data = {
        "model": "jina-embeddings-v3",
        "normalized": True,
        "embedding_type": "float",
        "dimensions": f"{dimensions}",
        "late_chunking": late_chunking,
        "input": texts,
    }
    data_list = await fetch_data(url, headers, data)
    return np.array([dp["embedding"] for dp in data_list])


@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def nvidia_openai_embedding(
    texts: list[str],
    model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
    base_url: str = "https://integrate.api.nvidia.com/v1",
    api_key: str = None,
    input_type: str = "passage",
    trunc: str = "NONE",
    encode: str = "float",
) -> np.ndarray:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    response = await openai_async_client.embeddings.create(
        model=model,
        input=texts,
        encoding_format=encode,
        extra_body={"input_type": input_type, "truncate": trunc},
    )
    return np.array([dp.embedding for dp in response.data])


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
    api_version: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["AZURE_OPENAI_API_KEY"] = api_key
    if base_url:
        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
    if api_version:
        os.environ["AZURE_OPENAI_API_VERSION"] = api_version

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def siliconcloud_embedding(
    texts: list[str],
    model: str = "netease-youdao/bce-embedding-base_v1",
    base_url: str = "https://api.siliconflow.cn/v1/embeddings",
    max_token_size: int = 512,
    api_key: str = None,
) -> np.ndarray:
    if api_key and not api_key.startswith("Bearer "):
        api_key = "Bearer " + api_key

    headers = {"Authorization": api_key, "Content-Type": "application/json"}

    truncate_texts = [text[0:max_token_size] for text in texts]

    payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}

    base64_strings = []
    async with aiohttp.ClientSession() as session:
        async with session.post(base_url, headers=headers, json=payload) as response:
            content = await response.json()
            if "code" in content:
                raise ValueError(content)
            base64_strings = [item["embedding"] for item in content["data"]]

    embeddings = []
    for string in base64_strings:
        decode_bytes = base64.b64decode(string)
        n = len(decode_bytes) // 4
        float_array = struct.unpack("<" + "f" * n, decode_bytes)
        embeddings.append(float_array)
    return np.array(embeddings)


async def bedrock_embedding(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
    )
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
                    body = json.dumps(
                        {
                            "inputText": text,
                            "embeddingTypes": ["float"],
                        }
                    )
                elif "v1" in model:
                    body = json.dumps({"inputText": text})
                else:
                    raise ValueError(f"Model {model} is not supported!")

                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
                    accept="application/json",
                    contentType="application/json",
                )

                response_body = await response.get("body").json()

                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )

            response = await bedrock_async_client.invoke_model(
                model=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

        return np.array(embed_texts)


async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    device = next(embed_model.parameters()).device
    input_ids = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        outputs = embed_model(input_ids)
        embeddings = outputs.last_hidden_state.mean(dim=1)
    if embeddings.dtype == torch.bfloat16:
        return embeddings.detach().to(torch.float32).cpu().numpy()
    else:
        return embeddings.detach().cpu().numpy()


async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    embed_text = []
    ollama_client = ollama.Client(**kwargs)
    for text in texts:
        data = ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

    return embed_text


async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    ollama_client = ollama.Client(**kwargs)
    data = ollama_client.embed(model=embed_model, input=texts)
    return data["embeddings"]


class Model(BaseModel):
    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: Dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        arbitrary_types_allowed = True


class MultiModel:
    def __init__(self, models: List[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self):
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self, prompt, system_prompt=None, history_messages=[], **kwargs
    ) -> str:
        kwargs.pop("model", None)
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)
        next_model = self._next_model()
        args = dict(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
            **next_model.kwargs,
        )

        return await next_model.gen_func(**args)


if __name__ == "__main__":
    import asyncio

    async def main():
        result = await gpt_4o_mini_complete("How are you?")
        print(result)

    asyncio.run(main())
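A minimal usage sketch for the MultiModel round-robin wrapper defined above (not part of the uploaded files; the model names, API keys, and the assumption that openai_complete_if_cache accepts model and api_key keyword arguments are illustrative only):

# Illustrative sketch, not part of PathRAG/llm.py.
# models = [
#     Model(gen_func=openai_complete_if_cache,
#           kwargs={"model": "gpt-4o-mini", "api_key": "KEY_A"}),  # hypothetical key
#     Model(gen_func=openai_complete_if_cache,
#           kwargs={"model": "gpt-4o-mini", "api_key": "KEY_B"}),  # hypothetical key
# ]
# multi = MultiModel(models)
# Each await of multi.llm_model_func(...) rotates to the next configured backend.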
|
PathRAG/operate.py
ADDED
|
@@ -0,0 +1,1239 @@
import asyncio
import json
import re
from tqdm.asyncio import tqdm as tqdm_async
from typing import Union
from collections import Counter, defaultdict
import warnings
import tiktoken
import time
import csv
from .utils import (
    logger,
    clean_str,
    compute_mdhash_id,
    decode_tokens_by_tiktoken,
    encode_string_by_tiktoken,
    is_float_regex,
    list_of_list_to_csv,
    pack_user_ass_to_openai_messages,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
    process_combine_contexts,
    compute_args_hash,
    handle_cache,
    save_to_cache,
    CacheData,
)
from .base import (
    BaseGraphStorage,
    BaseKVStorage,
    BaseVectorStorage,
    TextChunkSchema,
    QueryParam,
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS


def chunking_by_token_size(
    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], model_name=tiktoken_model
        )
        results.append(
            {
                "tokens": min(max_token_size, len(tokens) - start),
                "content": chunk_content.strip(),
                "chunk_order_index": index,
            }
        )
    return results


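# Illustrative sketch (not part of PathRAG/operate.py): a standalone check of the
# sliding-window arithmetic used by chunking_by_token_size above; plain integers
# stand in for tiktoken token ids.
#
# tokens = list(range(10))
# max_token_size, overlap_token_size = 4, 1
# chunks = [
#     tokens[start : start + max_token_size]
#     for start in range(0, len(tokens), max_token_size - overlap_token_size)
# ]
# # chunks == [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]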
async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
    global_config: dict,
) -> str:
    use_llm_func: callable = global_config["llm_model_func"]
    llm_max_tokens = global_config["llm_model_max_token_size"]
    tiktoken_model_name = global_config["tiktoken_model_name"]
    summary_max_tokens = global_config["entity_summary_to_max_tokens"]
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
    if len(tokens) < summary_max_tokens:
        return description
    prompt_template = PROMPTS["summarize_entity_descriptions"]
    use_description = decode_tokens_by_tiktoken(
        tokens[:llm_max_tokens], model_name=tiktoken_model_name
    )
    context_base = dict(
        entity_name=entity_or_relation_name,
        description_list=use_description.split(GRAPH_FIELD_SEP),
        language=language,
    )
    use_prompt = prompt_template.format(**context_base)
    logger.debug(f"Trigger summary: {entity_or_relation_name}")
    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
    return summary


async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None

    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        description=entity_description,
        source_id=entity_source_id,
    )


async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None

    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    weight = (
        float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    )
    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
    )


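# Illustrative sketch (not part of PathRAG/operate.py): the two handlers above
# expect tuple-delimited records. The "<|>" delimiter value is an assumption
# based on the defaults typically defined in prompt.py.
#
# record = '"entity"<|>"PATHRAG"<|>"TECHNOLOGY"<|>"A graph-based RAG pipeline."'
# record_attributes = record.split("<|>")
# # record_attributes[0] == '"entity"' and len(record_attributes) >= 4,
# # so _handle_single_entity_extraction would accept this record.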
async def _merge_nodes_then_upsert(
    entity_name: str,
    nodes_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    already_entity_types = []
    already_source_ids = []
    already_description = []

    already_node = await knowledge_graph_inst.get_node(entity_name)
    if already_node is not None:
        already_entity_types.append(already_node["entity_type"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_node["description"])

    entity_type = sorted(
        Counter(
            [dp["entity_type"] for dp in nodes_data] + already_entity_types
        ).items(),
        key=lambda x: x[1],
        reverse=True,
    )[0][0]
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in nodes_data] + already_description))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in nodes_data] + already_source_ids)
    )
    description = await _handle_entity_relation_summary(
        entity_name, description, global_config
    )
    node_data = dict(
        entity_type=entity_type,
        description=description,
        source_id=source_id,
    )
    await knowledge_graph_inst.upsert_node(
        entity_name,
        node_data=node_data,
    )
    node_data["entity_name"] = entity_name
    return node_data


async def _merge_edges_then_upsert(
    src_id: str,
    tgt_id: str,
    edges_data: list[dict],
    knowledge_graph_inst: BaseGraphStorage,
    global_config: dict,
):
    already_weights = []
    already_source_ids = []
    already_description = []
    already_keywords = []

    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
        already_weights.append(already_edge["weight"])
        already_source_ids.extend(
            split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
        )
        already_description.append(already_edge["description"])
        already_keywords.extend(
            split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
        )

    weight = sum([dp["weight"] for dp in edges_data] + already_weights)
    description = GRAPH_FIELD_SEP.join(
        sorted(set([dp["description"] for dp in edges_data] + already_description))
    )
    keywords = GRAPH_FIELD_SEP.join(
        sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
    )
    source_id = GRAPH_FIELD_SEP.join(
        set([dp["source_id"] for dp in edges_data] + already_source_ids)
    )
    for need_insert_id in [src_id, tgt_id]:
        if not (await knowledge_graph_inst.has_node(need_insert_id)):
            await knowledge_graph_inst.upsert_node(
                need_insert_id,
                node_data={
                    "source_id": source_id,
                    "description": description,
                    "entity_type": '"UNKNOWN"',
                },
            )
    description = await _handle_entity_relation_summary(
        f"({src_id}, {tgt_id})", description, global_config
    )
    await knowledge_graph_inst.upsert_edge(
        src_id,
        tgt_id,
        edge_data=dict(
            weight=weight,
            description=description,
            keywords=keywords,
            source_id=source_id,
        ),
    )

    edge_data = dict(
        src_id=src_id,
        tgt_id=tgt_id,
        description=description,
        keywords=keywords,
    )

    return edge_data


async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
    entity_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    global_config: dict,
) -> Union[BaseGraphStorage, None]:
    await asyncio.sleep(20)  # pause without blocking the event loop
    use_llm_func: callable = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]

    ordered_chunks = list(chunks.items())

    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )
    entity_types = global_config["addon_params"].get(
        "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
    )
    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["entity_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["entity_extraction_examples"])

    example_context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        language=language,
    )

    examples = examples.format(**example_context_base)

    entity_extract_prompt = PROMPTS["entity_extraction"]
    context_base = dict(
        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        entity_types=",".join(entity_types),
        examples=examples,
        language=language,
    )

    continue_prompt = PROMPTS["entiti_continue_extraction"]
    if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]

    already_processed = 0
    already_entities = 0
    already_relations = 0

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        nonlocal already_processed, already_entities, already_relations
        chunk_key = chunk_key_dp[0]
        chunk_dp = chunk_key_dp[1]
        content = chunk_dp["content"]
        hint_prompt = entity_extract_prompt.format(
            **context_base, input_text="{input_text}"
        ).format(**context_base, input_text=content)

        final_result = await use_llm_func(hint_prompt)
        history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
        for now_glean_index in range(entity_extract_max_gleaning):
            glean_result = await use_llm_func(continue_prompt, history_messages=history)

            history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
            final_result += glean_result
            if now_glean_index == entity_extract_max_gleaning - 1:
                break

            if_loop_result: str = await use_llm_func(
                if_loop_prompt, history_messages=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

        records = split_string_by_multi_markers(
            final_result,
            [context_base["record_delimiter"], context_base["completion_delimiter"]],
        )

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for record in records:
            record = re.search(r"\((.*)\)", record)
            if record is None:
                continue
            record = record.group(1)
            record_attributes = split_string_by_multi_markers(
                record, [context_base["tuple_delimiter"]]
            )
            if_entities = await _handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = await _handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        already_processed += 1
        already_entities += len(maybe_nodes)
        already_relations += len(maybe_edges)
        now_ticks = PROMPTS["process_tickers"][
            already_processed % len(PROMPTS["process_tickers"])
        ]
        print(
            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
            end="",
            flush=True,
        )
        return dict(maybe_nodes), dict(maybe_edges)

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
        total=len(ordered_chunks),
        desc="Extracting entities from chunks",
        unit="chunk",
    ):
        results.append(await result)

    maybe_nodes = defaultdict(list)
    maybe_edges = defaultdict(list)
    for m_nodes, m_edges in results:
        for k, v in m_nodes.items():
            maybe_nodes[k].extend(v)
        for k, v in m_edges.items():
            maybe_edges[k].extend(v)
    logger.info("Inserting entities into storage...")
    all_entities_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [
                _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
                for k, v in maybe_nodes.items()
            ]
        ),
        total=len(maybe_nodes),
        desc="Inserting entities",
        unit="entity",
    ):
        all_entities_data.append(await result)

    logger.info("Inserting relationships into storage...")
    all_relationships_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [
                _merge_edges_then_upsert(
                    k[0], k[1], v, knowledge_graph_inst, global_config
                )
                for k, v in maybe_edges.items()
            ]
        ),
        total=len(maybe_edges),
        desc="Inserting relationships",
        unit="relationship",
    ):
        all_relationships_data.append(await result)

    if not len(all_entities_data) and not len(all_relationships_data):
        logger.warning(
            "Didn't extract any entities and relationships, maybe your LLM is not working"
        )
        return None

    if not len(all_entities_data):
        logger.warning("Didn't extract any entities")
    if not len(all_relationships_data):
        logger.warning("Didn't extract any relationships")

    if entity_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                "content": dp["entity_name"] + dp["description"],
                "entity_name": dp["entity_name"],
            }
            for dp in all_entities_data
        }
        await entity_vdb.upsert(data_for_vdb)

    if relationships_vdb is not None:
        data_for_vdb = {
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "content": dp["keywords"]
                + dp["src_id"]
                + dp["tgt_id"]
                + dp["description"],
            }
            for dp in all_relationships_data
        }
        await relationships_vdb.upsert(data_for_vdb)

    return knowledge_graph_inst


async def kg_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    if query_param.mode not in ["hybrid"]:
        logger.error(f"Unknown mode {query_param.mode} in kg_query")
        return PROMPTS["fail_response"]

    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
    result = await use_model_func(kw_prompt, keyword_extraction=True)
    logger.info("kw_prompt result:")
    print(result)
    try:
        match = re.search(r"\{.*\}", result, re.DOTALL)
        if match:
            result = match.group(0)
            keywords_data = json.loads(result)

            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
        else:
            logger.error("No JSON-like structure found in the result.")
            return PROMPTS["fail_response"]

    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e} {result}")
        return PROMPTS["fail_response"]

    if hl_keywords == [] and ll_keywords == []:
        logger.warning("low_level_keywords and high_level_keywords are empty")
        return PROMPTS["fail_response"]
    if ll_keywords == [] and query_param.mode in ["hybrid"]:
        logger.warning("low_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        ll_keywords = ", ".join(ll_keywords)
    if hl_keywords == [] and query_param.mode in ["hybrid"]:
        logger.warning("high_level_keywords is empty")
        return PROMPTS["fail_response"]
    else:
        hl_keywords = ", ".join(hl_keywords)

    keywords = [ll_keywords, hl_keywords]
    context = await _build_query_context(
        keywords,
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
    )

    if query_param.only_need_context:
        return context
    if context is None:
        return PROMPTS["fail_response"]
    sys_prompt_temp = PROMPTS["rag_response"]
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
    if query_param.only_need_prompt:
        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
        stream=query_param.stream,
    )
    if isinstance(response, str) and len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )

    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            prompt=query,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=query_param.mode,
        ),
    )
    return response


async def _build_query_context(
    query: list,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
    hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""

    ll_keywords, hl_keywords = query[0], query[1]
    if query_param.mode in ["local", "hybrid"]:
        if ll_keywords == "":
            ll_entities_context, ll_relations_context, ll_text_units_context = (
                "",
                "",
                "",
            )
            warnings.warn(
                "Low Level context is None. Return empty Low entity/relationship/source"
            )
            query_param.mode = "global"
        else:
            (
                ll_entities_context,
                ll_relations_context,
                ll_text_units_context,
            ) = await _get_node_data(
                ll_keywords,
                knowledge_graph_inst,
                entities_vdb,
                text_chunks_db,
                query_param,
            )
    if query_param.mode in ["hybrid"]:
        if hl_keywords == "":
            hl_entities_context, hl_relations_context, hl_text_units_context = (
                "",
                "",
                "",
            )
            warnings.warn(
                "High Level context is None. Return empty High entity/relationship/source"
            )
            query_param.mode = "local"
        else:
            (
                hl_entities_context,
                hl_relations_context,
                hl_text_units_context,
            ) = await _get_edge_data(
                hl_keywords,
                knowledge_graph_inst,
                relationships_vdb,
                text_chunks_db,
                query_param,
            )
            if (
                hl_entities_context == ""
                and hl_relations_context == ""
                and hl_text_units_context == ""
            ):
                logger.warn("No high level context found. Switching to local mode.")
                query_param.mode = "local"
    if query_param.mode == "hybrid":
        entities_context, relations_context, text_units_context = combine_contexts(
            [hl_entities_context, hl_relations_context],
            [ll_entities_context, ll_relations_context],
            [hl_text_units_context, ll_text_units_context],
        )

    return f"""
-----global-information-----
-----high-level entity information-----
```csv
{hl_entities_context}
```
-----high-level relationship information-----
```csv
{hl_relations_context}
```
-----Sources-----
```csv
{text_units_context}
```
-----local-information-----
-----low-level entity information-----
```csv
{ll_entities_context}
```
-----low-level relationship information-----
```csv
{ll_relations_context}
```
"""


async def _get_node_data(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
):
    results = await entities_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return "", "", ""

    node_datas = await asyncio.gather(
        *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
    )
    if not all([n is not None for n in node_datas]):
        logger.warning("Some nodes are missing, maybe the storage is damaged")

    node_degrees = await asyncio.gather(
        *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
    )
    node_datas = [
        {**n, "entity_name": k["entity_name"], "rank": d}
        for k, n, d in zip(results, node_datas, node_degrees)
        if n is not None
    ]
    use_text_units = await _find_most_related_text_unit_from_entities(
        node_datas, query_param, text_chunks_db, knowledge_graph_inst
    )

    use_relations = await _find_most_related_edges_from_entities3(
        node_datas, query_param, knowledge_graph_inst
    )

    logger.info(
        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
    )

    entities_section_list = [["id", "entity", "type", "description", "rank"]]
    for i, n in enumerate(node_datas):
        entities_section_list.append(
            [
                i,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
                n["rank"],
            ]
        )
    entities_context = list_of_list_to_csv(entities_section_list)

    relations_section_list = [["id", "context"]]
    for i, e in enumerate(use_relations):
        relations_section_list.append([i, e])
    relations_context = list_of_list_to_csv(relations_section_list)

    text_units_section_list = [["id", "content"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append([i, t["content"]])
    text_units_context = list_of_list_to_csv(text_units_section_list)

    return entities_context, relations_context, text_units_context


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in node_datas
    ]
    edges = await asyncio.gather(
        *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
    )
    all_one_hop_nodes = set()
    for this_edges in edges:
        if not this_edges:
            continue
        all_one_hop_nodes.update([e[1] for e in this_edges])

    all_one_hop_nodes = list(all_one_hop_nodes)
    all_one_hop_nodes_data = await asyncio.gather(
        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
    )

    all_one_hop_text_units_lookup = {
        k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
        for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
        if v is not None and "source_id" in v
    }

    all_text_units_lookup = {}
    for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
        for c_id in this_text_units:
            if c_id not in all_text_units_lookup:
                all_text_units_lookup[c_id] = {
                    "data": await text_chunks_db.get_by_id(c_id),
                    "order": index,
                    "relation_counts": 0,
                }

            if this_edges:
                for e in this_edges:
                    if (
                        e[1] in all_one_hop_text_units_lookup
                        and c_id in all_one_hop_text_units_lookup[e[1]]
                    ):
                        all_text_units_lookup[c_id]["relation_counts"] += 1

    all_text_units = [
        {"id": k, **v}
        for k, v in all_text_units_lookup.items()
        if v is not None and v.get("data") is not None and "content" in v["data"]
    ]

    if not all_text_units:
        logger.warning("No valid text units found")
        return []

    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )

    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    all_text_units = [t["data"] for t in all_text_units]
    return all_text_units


async def _get_edge_data(
|
| 827 |
+
keywords,
|
| 828 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 829 |
+
relationships_vdb: BaseVectorStorage,
|
| 830 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 831 |
+
query_param: QueryParam,
|
| 832 |
+
):
|
| 833 |
+
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
| 834 |
+
|
| 835 |
+
if not len(results):
|
| 836 |
+
return "", "", ""
|
| 837 |
+
|
| 838 |
+
edge_datas = await asyncio.gather(
|
| 839 |
+
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
+
if not all([n is not None for n in edge_datas]):
|
| 843 |
+
logger.warning("Some edges are missing, maybe the storage is damaged")
|
| 844 |
+
edge_degree = await asyncio.gather(
|
| 845 |
+
*[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
|
| 846 |
+
)
|
| 847 |
+
edge_datas = [
|
| 848 |
+
{"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
|
| 849 |
+
for k, v, d in zip(results, edge_datas, edge_degree)
|
| 850 |
+
if v is not None
|
| 851 |
+
]
|
| 852 |
+
edge_datas = sorted(
|
| 853 |
+
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
| 854 |
+
)
|
| 855 |
+
edge_datas = truncate_list_by_token_size(
|
| 856 |
+
edge_datas,
|
| 857 |
+
key=lambda x: x["description"],
|
| 858 |
+
max_token_size=query_param.max_token_for_global_context,
|
| 859 |
+
)
|
| 860 |
+
|
| 861 |
+
use_entities = await _find_most_related_entities_from_relationships(
|
| 862 |
+
edge_datas, query_param, knowledge_graph_inst
|
| 863 |
+
)
|
| 864 |
+
use_text_units = await _find_related_text_unit_from_relationships(
|
| 865 |
+
edge_datas, query_param, text_chunks_db, knowledge_graph_inst
|
| 866 |
+
)
|
| 867 |
+
logger.info(
|
| 868 |
+
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
relations_section_list = [
|
| 872 |
+
["id", "source", "target", "description", "keywords", "weight", "rank"]
|
| 873 |
+
]
|
| 874 |
+
for i, e in enumerate(edge_datas):
|
| 875 |
+
relations_section_list.append(
|
| 876 |
+
[
|
| 877 |
+
i,
|
| 878 |
+
e["src_id"],
|
| 879 |
+
e["tgt_id"],
|
| 880 |
+
e["description"],
|
| 881 |
+
e["keywords"],
|
| 882 |
+
e["weight"],
|
| 883 |
+
e["rank"],
|
| 884 |
+
]
|
| 885 |
+
)
|
| 886 |
+
relations_context = list_of_list_to_csv(relations_section_list)
|
| 887 |
+
|
| 888 |
+
entites_section_list = [["id", "entity", "type", "description", "rank"]]
|
| 889 |
+
for i, n in enumerate(use_entities):
|
| 890 |
+
entites_section_list.append(
|
| 891 |
+
[
|
| 892 |
+
i,
|
| 893 |
+
n["entity_name"],
|
| 894 |
+
n.get("entity_type", "UNKNOWN"),
|
| 895 |
+
n.get("description", "UNKNOWN"),
|
| 896 |
+
n["rank"],
|
| 897 |
+
]
|
| 898 |
+
)
|
| 899 |
+
entities_context = list_of_list_to_csv(entites_section_list)
|
| 900 |
+
|
| 901 |
+
text_units_section_list = [["id", "content"]]
|
| 902 |
+
for i, t in enumerate(use_text_units):
|
| 903 |
+
text_units_section_list.append([i, t["content"]])
|
| 904 |
+
text_units_context = list_of_list_to_csv(text_units_section_list)
|
| 905 |
+
return entities_context, relations_context, text_units_context
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
async def _find_most_related_entities_from_relationships(
|
| 909 |
+
edge_datas: list[dict],
|
| 910 |
+
query_param: QueryParam,
|
| 911 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 912 |
+
):
|
| 913 |
+
entity_names = []
|
| 914 |
+
seen = set()
|
| 915 |
+
|
| 916 |
+
for e in edge_datas:
|
| 917 |
+
if e["src_id"] not in seen:
|
| 918 |
+
entity_names.append(e["src_id"])
|
| 919 |
+
seen.add(e["src_id"])
|
| 920 |
+
if e["tgt_id"] not in seen:
|
| 921 |
+
entity_names.append(e["tgt_id"])
|
| 922 |
+
seen.add(e["tgt_id"])
|
| 923 |
+
|
| 924 |
+
node_datas = await asyncio.gather(
|
| 925 |
+
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
node_degrees = await asyncio.gather(
|
| 929 |
+
*[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
|
| 930 |
+
)
|
| 931 |
+
node_datas = [
|
| 932 |
+
{**n, "entity_name": k, "rank": d}
|
| 933 |
+
for k, n, d in zip(entity_names, node_datas, node_degrees)
|
| 934 |
+
]
|
| 935 |
+
|
| 936 |
+
node_datas = truncate_list_by_token_size(
|
| 937 |
+
node_datas,
|
| 938 |
+
key=lambda x: x["description"],
|
| 939 |
+
max_token_size=query_param.max_token_for_local_context,
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
return node_datas
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
async def _find_related_text_unit_from_relationships(
|
| 946 |
+
edge_datas: list[dict],
|
| 947 |
+
query_param: QueryParam,
|
| 948 |
+
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
| 949 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 950 |
+
):
|
| 951 |
+
text_units = [
|
| 952 |
+
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
| 953 |
+
for dp in edge_datas
|
| 954 |
+
]
|
| 955 |
+
all_text_units_lookup = {}
|
| 956 |
+
|
| 957 |
+
for index, unit_list in enumerate(text_units):
|
| 958 |
+
for c_id in unit_list:
|
| 959 |
+
if c_id not in all_text_units_lookup:
|
| 960 |
+
chunk_data = await text_chunks_db.get_by_id(c_id)
|
| 961 |
+
|
| 962 |
+
if chunk_data is not None and "content" in chunk_data:
|
| 963 |
+
all_text_units_lookup[c_id] = {
|
| 964 |
+
"data": chunk_data,
|
| 965 |
+
"order": index,
|
| 966 |
+
}
|
| 967 |
+
|
| 968 |
+
if not all_text_units_lookup:
|
| 969 |
+
logger.warning("No valid text chunks found")
|
| 970 |
+
return []
|
| 971 |
+
|
| 972 |
+
all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
|
| 973 |
+
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
valid_text_units = [
|
| 977 |
+
t for t in all_text_units if t["data"] is not None and "content" in t["data"]
|
| 978 |
+
]
|
| 979 |
+
|
| 980 |
+
if not valid_text_units:
|
| 981 |
+
logger.warning("No valid text chunks after filtering")
|
| 982 |
+
return []
|
| 983 |
+
|
| 984 |
+
truncated_text_units = truncate_list_by_token_size(
|
| 985 |
+
valid_text_units,
|
| 986 |
+
key=lambda x: x["data"]["content"],
|
| 987 |
+
max_token_size=query_param.max_token_for_text_unit,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
|
| 991 |
+
|
| 992 |
+
return all_text_units
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def combine_contexts(entities, relationships, sources):
|
| 996 |
+
|
| 997 |
+
hl_entities, ll_entities = entities[0], entities[1]
|
| 998 |
+
hl_relationships, ll_relationships = relationships[0], relationships[1]
|
| 999 |
+
hl_sources, ll_sources = sources[0], sources[1]
|
| 1000 |
+
|
| 1001 |
+
combined_entities = process_combine_contexts(hl_entities, ll_entities)
|
| 1002 |
+
|
| 1003 |
+
combined_relationships = process_combine_contexts(
|
| 1004 |
+
hl_relationships, ll_relationships
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
combined_sources = process_combine_contexts(hl_sources, ll_sources)
|
| 1008 |
+
|
| 1009 |
+
return combined_entities, combined_relationships, combined_sources
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
import networkx as nx
|
| 1013 |
+
from collections import defaultdict
|
| 1014 |
+
async def find_paths_and_edges_with_stats(graph, target_nodes):
|
| 1015 |
+
|
| 1016 |
+
result = defaultdict(lambda: {"paths": [], "edges": set()})
|
| 1017 |
+
path_stats = {"1-hop": 0, "2-hop": 0, "3-hop": 0}
|
| 1018 |
+
one_hop_paths = []
|
| 1019 |
+
two_hop_paths = []
|
| 1020 |
+
three_hop_paths = []
|
| 1021 |
+
|
| 1022 |
+
async def dfs(current, target, path, depth):
|
| 1023 |
+
|
| 1024 |
+
if depth > 3:
|
| 1025 |
+
return
|
| 1026 |
+
if current == target:
|
| 1027 |
+
result[(path[0], target)]["paths"].append(list(path))
|
| 1028 |
+
for u, v in zip(path[:-1], path[1:]):
|
| 1029 |
+
result[(path[0], target)]["edges"].add(tuple(sorted((u, v))))
|
| 1030 |
+
if depth == 1:
|
| 1031 |
+
path_stats["1-hop"] += 1
|
| 1032 |
+
one_hop_paths.append(list(path))
|
| 1033 |
+
elif depth == 2:
|
| 1034 |
+
path_stats["2-hop"] += 1
|
| 1035 |
+
two_hop_paths.append(list(path))
|
| 1036 |
+
elif depth == 3:
|
| 1037 |
+
path_stats["3-hop"] += 1
|
| 1038 |
+
three_hop_paths.append(list(path))
|
| 1039 |
+
return
|
| 1040 |
+
neighbors = graph.neighbors(current)
|
| 1041 |
+
for neighbor in neighbors:
|
| 1042 |
+
if neighbor not in path:
|
| 1043 |
+
await dfs(neighbor, target, path + [neighbor], depth + 1)
|
| 1044 |
+
|
| 1045 |
+
for node1 in target_nodes:
|
| 1046 |
+
for node2 in target_nodes:
|
| 1047 |
+
if node1 != node2:
|
| 1048 |
+
await dfs(node1, node2, [node1], 0)
|
| 1049 |
+
|
| 1050 |
+
for key in result:
|
| 1051 |
+
result[key]["edges"] = list(result[key]["edges"])
|
| 1052 |
+
|
| 1053 |
+
return dict(result), path_stats , one_hop_paths, two_hop_paths, three_hop_paths
|
| 1054 |
+
def bfs_weighted_paths(G, path, source, target, threshold, alpha):
|
| 1055 |
+
results = []
|
| 1056 |
+
edge_weights = defaultdict(float)
|
| 1057 |
+
node = source
|
| 1058 |
+
follow_dict = {}
|
| 1059 |
+
|
| 1060 |
+
for p in path:
|
| 1061 |
+
for i in range(len(p) - 1):
|
| 1062 |
+
current = p[i]
|
| 1063 |
+
next_num = p[i + 1]
|
| 1064 |
+
|
| 1065 |
+
if current in follow_dict:
|
| 1066 |
+
follow_dict[current].add(next_num)
|
| 1067 |
+
else:
|
| 1068 |
+
follow_dict[current] = {next_num}
|
| 1069 |
+
|
| 1070 |
+
for neighbor in follow_dict[node]:
|
| 1071 |
+
edge_weights[(node, neighbor)] += 1/len(follow_dict[node])
|
| 1072 |
+
|
| 1073 |
+
if neighbor == target:
|
| 1074 |
+
results.append(([node, neighbor]))
|
| 1075 |
+
continue
|
| 1076 |
+
|
| 1077 |
+
if edge_weights[(node, neighbor)] > threshold:
|
| 1078 |
+
|
| 1079 |
+
for second_neighbor in follow_dict[neighbor]:
|
| 1080 |
+
weight = edge_weights[(node, neighbor)] * alpha / len(follow_dict[neighbor])
|
| 1081 |
+
edge_weights[(neighbor, second_neighbor)] += weight
|
| 1082 |
+
|
| 1083 |
+
if second_neighbor == target:
|
| 1084 |
+
results.append(([node, neighbor, second_neighbor]))
|
| 1085 |
+
continue
|
| 1086 |
+
|
| 1087 |
+
if edge_weights[(neighbor, second_neighbor)] > threshold:
|
| 1088 |
+
|
| 1089 |
+
for third_neighbor in follow_dict[second_neighbor]:
|
| 1090 |
+
weight = edge_weights[(neighbor, second_neighbor)] * alpha / len(follow_dict[second_neighbor])
|
| 1091 |
+
edge_weights[(second_neighbor, third_neighbor)] += weight
|
| 1092 |
+
|
| 1093 |
+
if third_neighbor == target :
|
| 1094 |
+
results.append(([node, neighbor, second_neighbor, third_neighbor]))
|
| 1095 |
+
continue
|
| 1096 |
+
path_weights = []
|
| 1097 |
+
for p in path:
|
| 1098 |
+
path_weight = 0
|
| 1099 |
+
for i in range(len(p) - 1):
|
| 1100 |
+
edge = (p[i], p[i + 1])
|
| 1101 |
+
path_weight += edge_weights.get(edge, 0)
|
| 1102 |
+
path_weights.append(path_weight/(len(p)-1))
|
| 1103 |
+
|
| 1104 |
+
combined = [(p, w) for p, w in zip(path, path_weights)]
|
| 1105 |
+
|
| 1106 |
+
return combined
|
| 1107 |
+
async def _find_most_related_edges_from_entities3(
|
| 1108 |
+
node_datas: list[dict],
|
| 1109 |
+
query_param: QueryParam,
|
| 1110 |
+
knowledge_graph_inst: BaseGraphStorage,
|
| 1111 |
+
):
|
| 1112 |
+
|
| 1113 |
+
G = nx.Graph()
|
| 1114 |
+
edges = await knowledge_graph_inst.edges()
|
| 1115 |
+
nodes = await knowledge_graph_inst.nodes()
|
| 1116 |
+
|
| 1117 |
+
for u, v in edges:
|
| 1118 |
+
G.add_edge(u, v)
|
| 1119 |
+
G.add_nodes_from(nodes)
|
| 1120 |
+
source_nodes = [dp["entity_name"] for dp in node_datas]
|
| 1121 |
+
result, path_stats, one_hop_paths, two_hop_paths, three_hop_paths = await find_paths_and_edges_with_stats(G, source_nodes)
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
threshold = 0.3
|
| 1125 |
+
alpha = 0.8
|
| 1126 |
+
all_results = []
|
| 1127 |
+
|
| 1128 |
+
for node1 in source_nodes:
|
| 1129 |
+
for node2 in source_nodes:
|
| 1130 |
+
if node1 != node2:
|
| 1131 |
+
if (node1, node2) in result:
|
| 1132 |
+
sub_G = nx.Graph()
|
| 1133 |
+
paths = result[(node1,node2)]['paths']
|
| 1134 |
+
edges = result[(node1,node2)]['edges']
|
| 1135 |
+
sub_G.add_edges_from(edges)
|
| 1136 |
+
results = bfs_weighted_paths(G, paths, node1, node2, threshold, alpha)
|
| 1137 |
+
all_results+= results
|
| 1138 |
+
all_results = sorted(all_results, key=lambda x: x[1], reverse=True)
|
| 1139 |
+
seen = set()
|
| 1140 |
+
result_edge = []
|
| 1141 |
+
for edge, weight in all_results:
|
| 1142 |
+
sorted_edge = tuple(sorted(edge))
|
| 1143 |
+
if sorted_edge not in seen:
|
| 1144 |
+
seen.add(sorted_edge)
|
| 1145 |
+
result_edge.append((edge, weight))
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
length_1 = int(len(one_hop_paths)/2)
|
| 1149 |
+
length_2 = int(len(two_hop_paths)/2)
|
| 1150 |
+
length_3 = int(len(three_hop_paths)/2)
|
| 1151 |
+
results = []
|
| 1152 |
+
if one_hop_paths!=[]:
|
| 1153 |
+
results = one_hop_paths[0:length_1]
|
| 1154 |
+
if two_hop_paths!=[]:
|
| 1155 |
+
results = results + two_hop_paths[0:length_2]
|
| 1156 |
+
if three_hop_paths!=[]:
|
| 1157 |
+
results =results + three_hop_paths[0:length_3]
|
| 1158 |
+
|
| 1159 |
+
length = len(results)
|
| 1160 |
+
total_edges = 15
|
| 1161 |
+
if length < total_edges:
|
| 1162 |
+
total_edges = length
|
| 1163 |
+
sort_result = []
|
| 1164 |
+
if result_edge:
|
| 1165 |
+
if len(result_edge)>total_edges:
|
| 1166 |
+
sort_result = result_edge[0:total_edges]
|
| 1167 |
+
else :
|
| 1168 |
+
sort_result = result_edge
|
| 1169 |
+
final_result = []
|
| 1170 |
+
for edge, weight in sort_result:
|
| 1171 |
+
final_result.append(edge)
|
| 1172 |
+
|
| 1173 |
+
relationship = []
|
| 1174 |
+
|
| 1175 |
+
for path in final_result:
|
| 1176 |
+
if len(path) == 4:
|
| 1177 |
+
s_name,b1_name,b2_name,t_name = path[0],path[1],path[2],path[3]
|
| 1178 |
+
edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
|
| 1179 |
+
edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1])
|
| 1180 |
+
edge2 = await knowledge_graph_inst.get_edge(path[2],path[3]) or await knowledge_graph_inst.get_edge(path[3], path[2])
|
| 1181 |
+
if edge0==None or edge1==None or edge2==None:
|
| 1182 |
+
print(path,"边丢失")
|
| 1183 |
+
if edge0==None:
|
| 1184 |
+
print("edge0丢失")
|
| 1185 |
+
if edge1==None:
|
| 1186 |
+
print("edge1丢失")
|
| 1187 |
+
if edge2==None:
|
| 1188 |
+
print("edge2丢失")
|
| 1189 |
+
continue
|
| 1190 |
+
e1 = "through edge ("+edge0["keywords"]+") to connect to "+s_name+" and "+b1_name+"."
|
| 1191 |
+
e2 = "through edge ("+edge1["keywords"]+") to connect to "+b1_name+" and "+b2_name+"."
|
| 1192 |
+
e3 = "through edge ("+edge2["keywords"]+") to connect to "+b2_name+" and "+t_name+"."
|
| 1193 |
+
s = await knowledge_graph_inst.get_node(s_name)
|
| 1194 |
+
s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
|
| 1195 |
+
b1 = await knowledge_graph_inst.get_node(b1_name)
|
| 1196 |
+
b1 = "The entity "+b1_name+" is a "+b1["entity_type"]+" with the description("+b1["description"]+")"
|
| 1197 |
+
b2 = await knowledge_graph_inst.get_node(b2_name)
|
| 1198 |
+
b2 = "The entity "+b2_name+" is a "+b2["entity_type"]+" with the description("+b2["description"]+")"
|
| 1199 |
+
t = await knowledge_graph_inst.get_node(t_name)
|
| 1200 |
+
t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
|
| 1201 |
+
relationship.append([s+e1+b1+"and"+b1+e2+b2+"and"+b2+e3+t])
|
| 1202 |
+
elif len(path) == 3:
|
| 1203 |
+
s_name,b_name,t_name = path[0],path[1],path[2]
|
| 1204 |
+
edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
|
| 1205 |
+
edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1])
|
| 1206 |
+
if edge0==None or edge1==None:
|
| 1207 |
+
print(path,"边丢失")
|
| 1208 |
+
continue
|
| 1209 |
+
e1 = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+b_name+"."
|
| 1210 |
+
e2 = "through edge("+edge1["keywords"]+") to connect to "+b_name+" and "+t_name+"."
|
| 1211 |
+
s = await knowledge_graph_inst.get_node(s_name)
|
| 1212 |
+
s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
|
| 1213 |
+
b = await knowledge_graph_inst.get_node(b_name)
|
| 1214 |
+
b = "The entity "+b_name+" is a "+b["entity_type"]+" with the description("+b["description"]+")"
|
| 1215 |
+
t = await knowledge_graph_inst.get_node(t_name)
|
| 1216 |
+
t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
|
| 1217 |
+
relationship.append([s+e1+b+"and"+b+e2+t])
|
| 1218 |
+
elif len(path) == 2:
|
| 1219 |
+
s_name,t_name = path[0],path[1]
|
| 1220 |
+
edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
|
| 1221 |
+
if edge0==None:
|
| 1222 |
+
print(path,"边丢失")
|
| 1223 |
+
continue
|
| 1224 |
+
e = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+t_name+"."
|
| 1225 |
+
s = await knowledge_graph_inst.get_node(s_name)
|
| 1226 |
+
s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
|
| 1227 |
+
t = await knowledge_graph_inst.get_node(t_name)
|
| 1228 |
+
t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
|
| 1229 |
+
relationship.append([s+e+t])
|
| 1230 |
+
|
| 1231 |
+
|
| 1232 |
+
relationship = truncate_list_by_token_size(
|
| 1233 |
+
relationship,
|
| 1234 |
+
key=lambda x: x[0],
|
| 1235 |
+
max_token_size=query_param.max_token_for_local_context,
|
| 1236 |
+
)
|
| 1237 |
+
|
| 1238 |
+
reversed_relationship = relationship[::-1]
|
| 1239 |
+
return reversed_relationship
|
PathRAG/prompt.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GRAPH_FIELD_SEP = "<SEP>"
|
| 2 |
+
|
| 3 |
+
PROMPTS = {}
|
| 4 |
+
|
| 5 |
+
PROMPTS["DEFAULT_LANGUAGE"] = "English"
|
| 6 |
+
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
|
| 7 |
+
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
|
| 8 |
+
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
|
| 9 |
+
PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
| 10 |
+
|
| 11 |
+
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
|
| 12 |
+
|
| 13 |
+
PROMPTS["entity_extraction"] = """-Goal-
|
| 14 |
+
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
| 15 |
+
Use {language} as output language.
|
| 16 |
+
|
| 17 |
+
-Steps-
|
| 18 |
+
1. Identify all entities. For each identified entity, extract the following information:
|
| 19 |
+
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name.
|
| 20 |
+
- entity_type: One of the following types: [{entity_types}]
|
| 21 |
+
- entity_description: Comprehensive description of the entity's attributes and activities
|
| 22 |
+
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
|
| 23 |
+
|
| 24 |
+
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
|
| 25 |
+
For each pair of related entities, extract the following information:
|
| 26 |
+
- source_entity: name of the source entity, as identified in step 1
|
| 27 |
+
- target_entity: name of the target entity, as identified in step 1
|
| 28 |
+
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
|
| 29 |
+
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
|
| 30 |
+
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
|
| 31 |
+
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
|
| 32 |
+
|
| 33 |
+
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
| 34 |
+
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
| 35 |
+
|
| 36 |
+
4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
| 37 |
+
|
| 38 |
+
5. When finished, output {completion_delimiter}
|
| 39 |
+
|
| 40 |
+
######################
|
| 41 |
+
-Examples-
|
| 42 |
+
######################
|
| 43 |
+
{examples}
|
| 44 |
+
|
| 45 |
+
#############################
|
| 46 |
+
-Real Data-
|
| 47 |
+
######################
|
| 48 |
+
Entity_types: {entity_types}
|
| 49 |
+
Text: {input_text}
|
| 50 |
+
######################
|
| 51 |
+
Output:
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
PROMPTS["entity_extraction_examples"] = [
|
| 55 |
+
"""Example 1:
|
| 56 |
+
|
| 57 |
+
Entity_types: [person, technology, mission, organization, location]
|
| 58 |
+
Text:
|
| 59 |
+
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
|
| 60 |
+
|
| 61 |
+
Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
|
| 62 |
+
|
| 63 |
+
The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
|
| 64 |
+
|
| 65 |
+
It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
|
| 66 |
+
################
|
| 67 |
+
Output:
|
| 68 |
+
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
|
| 69 |
+
("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
|
| 70 |
+
("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
|
| 71 |
+
("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
|
| 72 |
+
("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
|
| 73 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
|
| 74 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
|
| 75 |
+
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
|
| 76 |
+
("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
|
| 77 |
+
("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
|
| 78 |
+
("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
|
| 79 |
+
#############################""",
|
| 80 |
+
"""Example 2:
|
| 81 |
+
|
| 82 |
+
Entity_types: [person, technology, mission, organization, location]
|
| 83 |
+
Text:
|
| 84 |
+
They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
|
| 85 |
+
|
| 86 |
+
Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
|
| 87 |
+
|
| 88 |
+
Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
|
| 89 |
+
#############
|
| 90 |
+
Output:
|
| 91 |
+
("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
|
| 92 |
+
("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
|
| 93 |
+
("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
|
| 94 |
+
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
|
| 95 |
+
("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
|
| 96 |
+
("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
|
| 97 |
+
#############################""",
|
| 98 |
+
"""Example 3:
|
| 99 |
+
|
| 100 |
+
Entity_types: [person, role, technology, organization, event, location, concept]
|
| 101 |
+
Text:
|
| 102 |
+
their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
|
| 103 |
+
|
| 104 |
+
"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning."
|
| 105 |
+
|
| 106 |
+
Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
|
| 107 |
+
|
| 108 |
+
Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
|
| 109 |
+
|
| 110 |
+
The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
|
| 111 |
+
#############
|
| 112 |
+
Output:
|
| 113 |
+
("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
|
| 114 |
+
("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
|
| 115 |
+
("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
|
| 116 |
+
("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
|
| 117 |
+
("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
|
| 118 |
+
("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
|
| 119 |
+
("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
|
| 120 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
|
| 121 |
+
("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
|
| 122 |
+
("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
|
| 123 |
+
("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
|
| 124 |
+
#############################""",
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
PROMPTS[
|
| 128 |
+
"summarize_entity_descriptions"
|
| 129 |
+
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
|
| 130 |
+
Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
|
| 131 |
+
Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
|
| 132 |
+
If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
|
| 133 |
+
Make sure it is written in third person, and include the entity names so we the have full context.
|
| 134 |
+
Use {language} as output language.
|
| 135 |
+
|
| 136 |
+
#######
|
| 137 |
+
-Data-
|
| 138 |
+
Entities: {entity_name}
|
| 139 |
+
Description List: {description_list}
|
| 140 |
+
#######
|
| 141 |
+
Output:
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
PROMPTS[
|
| 145 |
+
"entiti_continue_extraction"
|
| 146 |
+
] = """MANY entities were missed in the last extraction. Add them below using the same format:
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
PROMPTS[
|
| 150 |
+
"entiti_if_loop_extraction"
|
| 151 |
+
] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
| 155 |
+
|
| 156 |
+
PROMPTS["rag_response"] = """---Role---
|
| 157 |
+
|
| 158 |
+
You are a helpful assistant responding to questions about data in the tables provided.
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
---Goal---
|
| 162 |
+
|
| 163 |
+
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
| 164 |
+
If you don't know the answer, just say so. Do not make anything up.
|
| 165 |
+
Do not include information where the supporting evidence for it is not provided.
|
| 166 |
+
|
| 167 |
+
---Target response length and format---
|
| 168 |
+
|
| 169 |
+
{response_type}
|
| 170 |
+
|
| 171 |
+
---Data tables---
|
| 172 |
+
|
| 173 |
+
{context_data}
|
| 174 |
+
|
| 175 |
+
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
PROMPTS["keywords_extraction"] = """---Role---
|
| 179 |
+
|
| 180 |
+
You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
|
| 181 |
+
|
| 182 |
+
---Goal---
|
| 183 |
+
|
| 184 |
+
Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
|
| 185 |
+
|
| 186 |
+
---Instructions---
|
| 187 |
+
|
| 188 |
+
- Output the keywords in JSON format.
|
| 189 |
+
- The JSON should have two keys:
|
| 190 |
+
- "high_level_keywords" for overarching concepts or themes.
|
| 191 |
+
- "low_level_keywords" for specific entities or details.
|
| 192 |
+
|
| 193 |
+
######################
|
| 194 |
+
-Examples-
|
| 195 |
+
######################
|
| 196 |
+
{examples}
|
| 197 |
+
|
| 198 |
+
#############################
|
| 199 |
+
-Real Data-
|
| 200 |
+
######################
|
| 201 |
+
Query: {query}
|
| 202 |
+
######################
|
| 203 |
+
The `Output` should be human text, not unicode characters. Keep the same language as `Query`.
|
| 204 |
+
Output:
|
| 205 |
+
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
PROMPTS["keywords_extraction_examples"] = [
|
| 209 |
+
"""Example 1:
|
| 210 |
+
|
| 211 |
+
Query: "How does international trade influence global economic stability?"
|
| 212 |
+
################
|
| 213 |
+
Output:
|
| 214 |
+
{{
|
| 215 |
+
"high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
|
| 216 |
+
"low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
|
| 217 |
+
}}
|
| 218 |
+
#############################""",
|
| 219 |
+
"""Example 2:
|
| 220 |
+
|
| 221 |
+
Query: "What are the environmental consequences of deforestation on biodiversity?"
|
| 222 |
+
################
|
| 223 |
+
Output:
|
| 224 |
+
{{
|
| 225 |
+
"high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
|
| 226 |
+
"low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
|
| 227 |
+
}}
|
| 228 |
+
#############################""",
|
| 229 |
+
"""Example 3:
|
| 230 |
+
|
| 231 |
+
Query: "What is the role of education in reducing poverty?"
|
| 232 |
+
################
|
| 233 |
+
Output:
|
| 234 |
+
{{
|
| 235 |
+
"high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
|
| 236 |
+
"low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
|
| 237 |
+
}}
|
| 238 |
+
#############################""",
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
PROMPTS["naive_rag_response"] = """---Role---
|
| 243 |
+
|
| 244 |
+
You are a helpful assistant responding to questions about documents provided.
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
---Goal---
|
| 248 |
+
|
| 249 |
+
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
| 250 |
+
If you don't know the answer, just say so. Do not make anything up.
|
| 251 |
+
Do not include information where the supporting evidence for it is not provided.
|
| 252 |
+
|
| 253 |
+
---Target response length and format---
|
| 254 |
+
|
| 255 |
+
{response_type}
|
| 256 |
+
|
| 257 |
+
---Documents---
|
| 258 |
+
|
| 259 |
+
{content_data}
|
| 260 |
+
|
| 261 |
+
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
PROMPTS[
|
| 265 |
+
"similarity_check"
|
| 266 |
+
] = """Please analyze the similarity between these two questions:
|
| 267 |
+
|
| 268 |
+
Question 1: {original_prompt}
|
| 269 |
+
Question 2: {cached_prompt}
|
| 270 |
+
|
| 271 |
+
Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
|
| 272 |
+
1. Whether these two questions are semantically similar
|
| 273 |
+
2. Whether the answer to Question 2 can be used to answer Question 1
|
| 274 |
+
Similarity score criteria:
|
| 275 |
+
0: Completely unrelated or answer cannot be reused, including but not limited to:
|
| 276 |
+
- The questions have different topics
|
| 277 |
+
- The locations mentioned in the questions are different
|
| 278 |
+
- The times mentioned in the questions are different
|
| 279 |
+
- The specific individuals mentioned in the questions are different
|
| 280 |
+
- The specific events mentioned in the questions are different
|
| 281 |
+
- The background information in the questions is different
|
| 282 |
+
- The key conditions in the questions are different
|
| 283 |
+
1: Identical and answer can be directly reused
|
| 284 |
+
0.5: Partially related and answer needs modification to be used
|
| 285 |
+
Return only a number between 0-1, without any additional content.
|
| 286 |
+
"""
|
PathRAG/storage.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import html
|
| 3 |
+
import os
|
| 4 |
+
from tqdm.asyncio import tqdm as tqdm_async
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Union, cast
|
| 7 |
+
import networkx as nx
|
| 8 |
+
import numpy as np
|
| 9 |
+
from nano_vectordb import NanoVectorDB
|
| 10 |
+
|
| 11 |
+
from .utils import (
|
| 12 |
+
logger,
|
| 13 |
+
load_json,
|
| 14 |
+
write_json,
|
| 15 |
+
compute_mdhash_id,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from .base import (
|
| 19 |
+
BaseGraphStorage,
|
| 20 |
+
BaseKVStorage,
|
| 21 |
+
BaseVectorStorage,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class JsonKVStorage(BaseKVStorage):
|
| 27 |
+
def __post_init__(self):
|
| 28 |
+
working_dir = self.global_config["working_dir"]
|
| 29 |
+
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
|
| 30 |
+
self._data = load_json(self._file_name) or {}
|
| 31 |
+
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
| 32 |
+
|
| 33 |
+
async def all_keys(self) -> list[str]:
|
| 34 |
+
return list(self._data.keys())
|
| 35 |
+
|
| 36 |
+
async def index_done_callback(self):
|
| 37 |
+
write_json(self._data, self._file_name)
|
| 38 |
+
|
| 39 |
+
async def get_by_id(self, id):
|
| 40 |
+
return self._data.get(id, None)
|
| 41 |
+
|
| 42 |
+
async def get_by_ids(self, ids, fields=None):
|
| 43 |
+
if fields is None:
|
| 44 |
+
return [self._data.get(id, None) for id in ids]
|
| 45 |
+
return [
|
| 46 |
+
(
|
| 47 |
+
{k: v for k, v in self._data[id].items() if k in fields}
|
| 48 |
+
if self._data.get(id, None)
|
| 49 |
+
else None
|
| 50 |
+
)
|
| 51 |
+
for id in ids
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
async def filter_keys(self, data: list[str]) -> set[str]:
|
| 55 |
+
return set([s for s in data if s not in self._data])
|
| 56 |
+
|
| 57 |
+
async def upsert(self, data: dict[str, dict]):
|
| 58 |
+
left_data = {k: v for k, v in data.items() if k not in self._data}
|
| 59 |
+
self._data.update(left_data)
|
| 60 |
+
return left_data
|
| 61 |
+
|
| 62 |
+
async def drop(self):
|
| 63 |
+
self._data = {}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class NanoVectorDBStorage(BaseVectorStorage):
|
| 68 |
+
cosine_better_than_threshold: float = 0.2
|
| 69 |
+
|
| 70 |
+
def __post_init__(self):
|
| 71 |
+
self._client_file_name = os.path.join(
|
| 72 |
+
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
| 73 |
+
)
|
| 74 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
| 75 |
+
self._client = NanoVectorDB(
|
| 76 |
+
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
| 77 |
+
)
|
| 78 |
+
self.cosine_better_than_threshold = self.global_config.get(
|
| 79 |
+
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
async def upsert(self, data: dict[str, dict]):
|
| 83 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
| 84 |
+
if not len(data):
|
| 85 |
+
logger.warning("You insert an empty data to vector DB")
|
| 86 |
+
return []
|
| 87 |
+
list_data = [
|
| 88 |
+
{
|
| 89 |
+
"__id__": k,
|
| 90 |
+
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
|
| 91 |
+
}
|
| 92 |
+
for k, v in data.items()
|
| 93 |
+
]
|
| 94 |
+
contents = [v["content"] for v in data.values()]
|
| 95 |
+
batches = [
|
| 96 |
+
contents[i : i + self._max_batch_size]
|
| 97 |
+
for i in range(0, len(contents), self._max_batch_size)
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
async def wrapped_task(batch):
|
| 101 |
+
result = await self.embedding_func(batch)
|
| 102 |
+
pbar.update(1)
|
| 103 |
+
return result
|
| 104 |
+
|
| 105 |
+
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
| 106 |
+
pbar = tqdm_async(
|
| 107 |
+
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
|
| 108 |
+
)
|
| 109 |
+
embeddings_list = await asyncio.gather(*embedding_tasks)
|
| 110 |
+
|
| 111 |
+
embeddings = np.concatenate(embeddings_list)
|
| 112 |
+
if len(embeddings) == len(list_data):
|
| 113 |
+
for i, d in enumerate(list_data):
|
| 114 |
+
d["__vector__"] = embeddings[i]
|
| 115 |
+
results = self._client.upsert(datas=list_data)
|
| 116 |
+
return results
|
| 117 |
+
else:
|
| 118 |
+
# sometimes the embedding is not returned correctly. just log it.
|
| 119 |
+
logger.error(
|
| 120 |
+
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
async def query(self, query: str, top_k=5):
|
| 124 |
+
embedding = await self.embedding_func([query])
|
| 125 |
+
embedding = embedding[0]
|
| 126 |
+
results = self._client.query(
|
| 127 |
+
query=embedding,
|
| 128 |
+
top_k=top_k,
|
| 129 |
+
better_than_threshold=self.cosine_better_than_threshold,
|
| 130 |
+
)
|
| 131 |
+
results = [
|
| 132 |
+
{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
|
| 133 |
+
]
|
| 134 |
+
return results
|
| 135 |
+
|
| 136 |
+
@property
|
| 137 |
+
def client_storage(self):
|
| 138 |
+
return getattr(self._client, "_NanoVectorDB__storage")
|
| 139 |
+
|
| 140 |
+
async def delete_entity(self, entity_name: str):
|
| 141 |
+
try:
|
| 142 |
+
entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
|
| 143 |
+
|
| 144 |
+
if self._client.get(entity_id):
|
| 145 |
+
self._client.delete(entity_id)
|
| 146 |
+
logger.info(f"Entity {entity_name} have been deleted.")
|
| 147 |
+
else:
|
| 148 |
+
logger.info(f"No entity found with name {entity_name}.")
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"Error while deleting entity {entity_name}: {e}")
|
| 151 |
+
|
| 152 |
+
async def delete_relation(self, entity_name: str):
|
| 153 |
+
try:
|
| 154 |
+
relations = [
|
| 155 |
+
dp
|
| 156 |
+
for dp in self.client_storage["data"]
|
| 157 |
+
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
|
| 158 |
+
]
|
| 159 |
+
ids_to_delete = [relation["__id__"] for relation in relations]
|
| 160 |
+
|
| 161 |
+
if ids_to_delete:
|
| 162 |
+
self._client.delete(ids_to_delete)
|
| 163 |
+
logger.info(
|
| 164 |
+
f"All relations related to entity {entity_name} have been deleted."
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
logger.info(f"No relations found for entity {entity_name}.")
|
| 168 |
+
except Exception as e:
|
| 169 |
+
logger.error(
|
| 170 |
+
f"Error while deleting relations for entity {entity_name}: {e}"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
async def index_done_callback(self):
|
| 174 |
+
self._client.save()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclass
|
| 178 |
+
class NetworkXStorage(BaseGraphStorage):
|
| 179 |
+
@staticmethod
|
| 180 |
+
def load_nx_graph(file_name) -> nx.DiGraph:
|
| 181 |
+
if os.path.exists(file_name):
|
| 182 |
+
return nx.read_graphml(file_name)
|
| 183 |
+
return None
|
| 184 |
+
# def load_nx_graph(file_name) -> nx.Graph:
|
| 185 |
+
# if os.path.exists(file_name):
|
| 186 |
+
# return nx.read_graphml(file_name)
|
| 187 |
+
# return None
|
| 188 |
+
|
| 189 |
+
@staticmethod
|
| 190 |
+
def write_nx_graph(graph: nx.DiGraph, file_name):
|
| 191 |
+
logger.info(
|
| 192 |
+
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
|
| 193 |
+
)
|
| 194 |
+
nx.write_graphml(graph, file_name)
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
|
| 198 |
+
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
|
| 199 |
+
Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
|
| 200 |
+
"""
|
| 201 |
+
from graspologic.utils import largest_connected_component
|
| 202 |
+
|
| 203 |
+
graph = graph.copy()
|
| 204 |
+
        graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
        graph = nx.relabel_nodes(graph, node_mapping)
        return NetworkXStorage._stabilize_graph(graph)

    @staticmethod
    def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
        """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
        Ensure an undirected graph with the same relationships will always be read the same way.
        """
        fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()

        sorted_nodes = graph.nodes(data=True)
        sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])

        fixed_graph.add_nodes_from(sorted_nodes)
        edges = list(graph.edges(data=True))

        if not graph.is_directed():

            def _sort_source_target(edge):
                source, target, edge_data = edge
                if source > target:
                    temp = source
                    source = target
                    target = temp
                return source, target, edge_data

            edges = [_sort_source_target(edge) for edge in edges]

        def _get_edge_key(source: Any, target: Any) -> str:
            return f"{source} -> {target}"

        edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))

        fixed_graph.add_edges_from(edges)
        return fixed_graph

    def __post_init__(self):
        self._graphml_xml_file = os.path.join(
            self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
        )
        preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
        if preloaded_graph is not None:
            logger.info(
                f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
            )
        self._graph = preloaded_graph or nx.DiGraph()
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def index_done_callback(self):
        NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

    async def has_node(self, node_id: str) -> bool:
        return self._graph.has_node(node_id)

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        return self._graph.has_edge(source_node_id, target_node_id)

    async def get_node(self, node_id: str) -> Union[dict, None]:
        return self._graph.nodes.get(node_id)

    async def node_degree(self, node_id: str) -> int:
        return self._graph.degree(node_id)

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return self._graph.degree(src_id) + self._graph.degree(tgt_id)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        return self._graph.edges.get((source_node_id, target_node_id))

    async def get_node_edges(self, source_node_id: str):
        if self._graph.has_node(source_node_id):
            return list(self._graph.edges(source_node_id))
        return None

    async def get_node_in_edges(self, source_node_id: str):
        if self._graph.has_node(source_node_id):
            return list(self._graph.in_edges(source_node_id))
        return None

    async def get_node_out_edges(self, source_node_id: str):
        if self._graph.has_node(source_node_id):
            return list(self._graph.out_edges(source_node_id))
        return None

    async def get_pagerank(self, source_node_id: str):
        pagerank_list = nx.pagerank(self._graph)
        if source_node_id in pagerank_list:
            return pagerank_list[source_node_id]
        else:
            logger.warning(f"PageRank lookup failed: {source_node_id} not in graph")
            return None

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        self._graph.add_node(node_id, **node_data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        self._graph.add_edge(source_node_id, target_node_id, **edge_data)

    async def delete_node(self, node_id: str):
        """
        Delete a node from the graph based on the specified node_id.

        :param node_id: The node_id to delete
        """
        if self._graph.has_node(node_id):
            self._graph.remove_node(node_id)
            logger.info(f"Node {node_id} deleted from the graph.")
        else:
            logger.warning(f"Node {node_id} not found in the graph for deletion.")

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        if algorithm not in self._node_embed_algorithms:
            raise ValueError(f"Node embedding algorithm {algorithm} not supported")
        return await self._node_embed_algorithms[algorithm]()

    # @TODO: NOT USED
    async def _node2vec_embed(self):
        from graspologic import embed

        embeddings, nodes = embed.node2vec_embed(
            self._graph,
            **self.global_config["node2vec_params"],
        )

        nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
        return embeddings, nodes_ids

    async def edges(self):
        return self._graph.edges()

    async def nodes(self):
        return self._graph.nodes()
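The stabilization step above exists so that the same set of relationships always serializes to the same GraphML file, regardless of the order in which nodes and edges were inserted. A minimal sketch of that property (illustrative only, not part of the upload; it assumes the package and its requirements are importable):

import networkx as nx
from PathRAG.storage import NetworkXStorage

# Two graphs with identical relationships, built in different orders.
g1 = nx.Graph()
g1.add_edge("B", "A", weight=1.0)
g1.add_edge("C", "A", weight=2.0)

g2 = nx.Graph()
g2.add_edge("A", "C", weight=2.0)
g2.add_edge("A", "B", weight=1.0)

# After stabilization, node and edge orderings match exactly.
s1 = NetworkXStorage._stabilize_graph(g1)
s2 = NetworkXStorage._stabilize_graph(g2)
assert list(s1.edges(data=True)) == list(s2.edges(data=True))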
PathRAG/utils.py
ADDED
@@ -0,0 +1,527 @@
import asyncio
import html
import io
import csv
import json
import logging
import os
import re
from dataclasses import dataclass
from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET

import numpy as np
import tiktoken

from PathRAG.prompt import PROMPTS


class UnlimitedSemaphore:
    """A no-op async context manager used when no concurrency limit is wanted."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass


ENCODER = None

logger = logging.getLogger("PathRAG")


def set_logger(log_file: str):
    logger.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    if not logger.handlers:
        logger.addHandler(file_handler)


@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: callable
    concurrent_limit: int = 16

    def __post_init__(self):
        # A limit of 0 means "unbounded": use the no-op semaphore instead.
        if self.concurrent_limit != 0:
            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
        else:
            self._semaphore = UnlimitedSemaphore()

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        async with self._semaphore:
            return await self.func(*args, **kwargs)

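A small usage sketch for EmbeddingFunc (illustrative only, not part of utils.py; fake_embed and the dimensions are made up): the semaphore created in __post_init__ caps how many embedding calls run at once.

import asyncio
import numpy as np
from PathRAG.utils import EmbeddingFunc

async def fake_embed(texts: list[str]) -> np.ndarray:
    await asyncio.sleep(0.01)               # stand-in for a real embedding API call
    return np.random.rand(len(texts), 8)    # 8-dim vectors, demo only

embed = EmbeddingFunc(
    embedding_dim=8, max_token_size=8192, func=fake_embed, concurrent_limit=2
)

async def main():
    # Six calls are issued together, but at most two run at any moment.
    vecs = await asyncio.gather(*(embed([f"doc {i}"]) for i in range(6)))
    print(len(vecs), vecs[0].shape)          # 6 (1, 8)

asyncio.run(main())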
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        if maybe_json_str is not None:
            maybe_json_str = maybe_json_str.group(0)
            maybe_json_str = maybe_json_str.replace("\\n", "")
            maybe_json_str = maybe_json_str.replace("\n", "")
            maybe_json_str = maybe_json_str.replace("'", '"')

            return maybe_json_str
    except Exception:
        pass

    return None


def convert_response_to_json(response: str) -> dict:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()


def compute_mdhash_id(content, prefix: str = ""):
    return prefix + md5(content.encode()).hexdigest()


def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    def final_decro(func):
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            result = await func(*args, **kwargs)
            __current_size -= 1
            return result

        return wait_func

    return final_decro


def wrap_embedding_func_with_attrs(**kwargs):
    def final_decro(func) -> EmbeddingFunc:
        new_func = EmbeddingFunc(**kwargs, func=func)
        return new_func

    return final_decro


def load_json(file_name):
    if not os.path.exists(file_name):
        return None
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def write_json(json_obj, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=2, ensure_ascii=False)


def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o-mini"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    tokens = ENCODER.encode(content)
    return tokens


def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o-mini"):
    global ENCODER
    if ENCODER is None:
        ENCODER = tiktoken.encoding_for_model(model_name)
    content = ENCODER.decode(tokens)
    return content


def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]

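Two of the helpers above in action (a quick illustration, not repo code):

from PathRAG.utils import compute_mdhash_id, split_string_by_multi_markers

# Deterministic IDs: the same content always maps to the same md5-based key.
chunk_id = compute_mdhash_id("some chunk of text", prefix="chunk-")
print(chunk_id)   # "chunk-" followed by the 32-character md5 hex digest of the text

# Split on several delimiters at once and drop empty pieces.
parts = split_string_by_multi_markers("alpha;;beta##gamma", [";;", "##"])
print(parts)      # ['alpha', 'beta', 'gamma']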
def clean_str(input: Any) -> str:
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())

    return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value):
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encode_string_by_tiktoken(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data


def list_of_list_to_csv(data: List[List[str]]) -> str:
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerows(data)
    return output.getvalue()


def csv_string_to_list(csv_string: str) -> List[List[str]]:
    output = io.StringIO(csv_string)
    reader = csv.reader(output)
    return [row for row in reader]


def save_data_to_file(data, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def xml_to_json(xml_file):
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            node_data = {
                "id": node.get("id").strip('"'),
                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
                if node.find("./data[@key='d0']", namespace) is not None
                else "",
                "description": node.find("./data[@key='d1']", namespace).text
                if node.find("./data[@key='d1']", namespace) is not None
                else "",
                "source_id": node.find("./data[@key='d2']", namespace).text
                if node.find("./data[@key='d2']", namespace) is not None
                else "",
            }
            data["nodes"].append(node_data)

        for edge in root.findall(".//edge", namespace):
            edge_data = {
                "source": edge.get("source").strip('"'),
                "target": edge.get("target").strip('"'),
                "weight": float(edge.find("./data[@key='d3']", namespace).text)
                if edge.find("./data[@key='d3']", namespace) is not None
                else 0.0,
                "description": edge.find("./data[@key='d4']", namespace).text
                if edge.find("./data[@key='d4']", namespace) is not None
                else "",
                "keywords": edge.find("./data[@key='d5']", namespace).text
                if edge.find("./data[@key='d5']", namespace) is not None
                else "",
                "source_id": edge.find("./data[@key='d6']", namespace).text
                if edge.find("./data[@key='d6']", namespace) is not None
                else "",
            }
            data["edges"].append(edge_data)

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")

        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


def process_combine_contexts(hl, ll):
    header = None
    list_hl = csv_string_to_list(hl.strip())
    list_ll = csv_string_to_list(ll.strip())

    if list_hl:
        header = list_hl[0]
        list_hl = list_hl[1:]
    if list_ll:
        header = list_ll[0]
        list_ll = list_ll[1:]
    if header is None:
        return ""

    if list_hl:
        list_hl = [",".join(item[1:]) for item in list_hl if item]
    if list_ll:
        list_ll = [",".join(item[1:]) for item in list_ll if item]

    combined_sources = []
    seen = set()

    for item in list_hl + list_ll:
        if item and item not in seen:
            combined_sources.append(item)
            seen.add(item)

    combined_sources_result = [",\t".join(header)]

    for i, item in enumerate(combined_sources, start=1):
        combined_sources_result.append(f"{i},\t{item}")

    combined_sources_result = "\n".join(combined_sources_result)

    return combined_sources_result

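A short sketch of the list/CSV/token helpers above (illustrative; it assumes tiktoken can load the gpt-4o-mini encoding on this machine):

from PathRAG.utils import (
    list_of_list_to_csv,
    csv_string_to_list,
    truncate_list_by_token_size,
)

# list_of_list_to_csv and csv_string_to_list are inverses for simple string rows.
rows = [["id", "entity", "description"], ["1", "Alice", "a researcher"]]
assert csv_string_to_list(list_of_list_to_csv(rows)) == rows

# truncate_list_by_token_size keeps the longest prefix whose tokens fit the budget.
docs = [{"content": "short"}, {"content": "a much longer passage " * 50}]
kept = truncate_list_by_token_size(docs, key=lambda d: d["content"], max_token_size=64)
print(len(kept))   # 1: adding the second document would exceed 64 tokens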
async def get_best_cached_response(
    hashing_kv,
    current_embedding,
    similarity_threshold=0.95,
    mode="default",
    use_llm_check=False,
    llm_func=None,
    original_prompt=None,
) -> Union[str, None]:
    mode_cache = await hashing_kv.get_by_id(mode)
    if not mode_cache:
        return None

    best_similarity = -1
    best_response = None
    best_prompt = None
    best_cache_id = None

    for cache_id, cache_data in mode_cache.items():
        if cache_data["embedding"] is None:
            continue

        cached_quantized = np.frombuffer(
            bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
        ).reshape(cache_data["embedding_shape"])
        cached_embedding = dequantize_embedding(
            cached_quantized,
            cache_data["embedding_min"],
            cache_data["embedding_max"],
        )

        similarity = cosine_similarity(current_embedding, cached_embedding)
        if similarity > best_similarity:
            best_similarity = similarity
            best_response = cache_data["return"]
            best_prompt = cache_data["original_prompt"]
            best_cache_id = cache_id

    if best_similarity > similarity_threshold:
        if use_llm_check and llm_func and original_prompt and best_prompt:
            compare_prompt = PROMPTS["similarity_check"].format(
                original_prompt=original_prompt, cached_prompt=best_prompt
            )

            try:
                llm_result = await llm_func(compare_prompt)
                llm_result = llm_result.strip()
                llm_similarity = float(llm_result)

                best_similarity = llm_similarity
                if best_similarity < similarity_threshold:
                    log_data = {
                        "event": "llm_check_cache_rejected",
                        "original_question": original_prompt[:100] + "..."
                        if len(original_prompt) > 100
                        else original_prompt,
                        "cached_question": best_prompt[:100] + "..."
                        if len(best_prompt) > 100
                        else best_prompt,
                        "similarity_score": round(best_similarity, 4),
                        "threshold": similarity_threshold,
                    }
                    logger.info(json.dumps(log_data, ensure_ascii=False))
                    return None
            except Exception as e:
                logger.warning(f"LLM similarity check failed: {e}")
                return None

        prompt_display = (
            best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
        )
        log_data = {
            "event": "cache_hit",
            "mode": mode,
            "similarity": round(best_similarity, 4),
            "cache_id": best_cache_id,
            "original_prompt": prompt_display,
        }
        logger.info(json.dumps(log_data, ensure_ascii=False))
        return best_response
    return None


def cosine_similarity(v1, v2):
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    return dot_product / (norm1 * norm2)


def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
    min_val = embedding.min()
    max_val = embedding.max()

    scale = (2**bits - 1) / (max_val - min_val)
    quantized = np.round((embedding - min_val) * scale).astype(np.uint8)

    return quantized, min_val, max_val


def dequantize_embedding(
    quantized: np.ndarray, min_val: float, max_val: float, bits=8
) -> np.ndarray:
    scale = (max_val - min_val) / (2**bits - 1)
    return (quantized * scale + min_val).astype(np.float32)


async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    # Returns (cached_response, quantized_embedding, min_val, max_val); the last
    # three are only populated when the embedding cache is enabled and no hit is found.
    if hashing_kv is None:
        return None, None, None, None

    if mode == "naive":
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None
        return None, None, None, None

    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config",
        {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]
    use_llm_check = embedding_cache_config.get("use_llm_check", False)

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        llm_model_func = hashing_kv.global_config.get("llm_model_func")

        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
            use_llm_check=use_llm_check,
            llm_func=llm_model_func if use_llm_check else None,
            original_prompt=prompt if use_llm_check else None,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    return None, quantized, min_val, max_val


@dataclass
class CacheData:
    args_hash: str
    content: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"


async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
        return

    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "embedding": cache_data.quantized.tobytes().hex()
        if cache_data.quantized is not None
        else None,
        "embedding_shape": cache_data.quantized.shape
        if cache_data.quantized is not None
        else None,
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})


def safe_unicode_decode(content):
    unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")

    def replace_unicode_escape(match):
        return chr(int(match.group(1), 16))

    decoded_content = unicode_escape_pattern.sub(
        replace_unicode_escape, content.decode("utf-8")
    )

    return decoded_content
PathRAG/v1_test.py
ADDED
@@ -0,0 +1,49 @@
import os
from PathRAG import PathRAG, QueryParam
from PathRAG.llm import gpt_4o_mini_complete
import time
import openai
import json
import re
from openai import OpenAI
import sys
from collections import defaultdict

WORKING_DIR = ""

api_key = ""
os.environ["OPENAI_API_KEY"] = api_key
base_url = "https://api.openai.com/v1"
os.environ["OPENAI_API_BASE"] = base_url


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = PathRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,
)

data_file = ""
question = ""
with open(data_file) as f:
    rag.insert(f.read())

print(rag.query(question, param=QueryParam(mode="hybrid")))
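The placeholders above (WORKING_DIR, api_key, data_file, question) are intentionally left empty in the upload. A filled-in run might look like the following; every value here is only an illustration, not something shipped with the repo:

WORKING_DIR = "./pathrag_working_dir"   # example path (hypothetical)
api_key = "sk-..."                      # your own OpenAI API key
data_file = "./data/corpus.txt"         # example input document (hypothetical)
question = "What are the main themes of this document?"   # example query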
requirements.txt
ADDED
@@ -0,0 +1,28 @@
accelerate
aioboto3
aiohttp

# database packages
graspologic
hnswlib
nano-vectordb
neo4j
networkx
ollama
openai
oracledb
psycopg[binary,pool]
pymilvus
pymongo
pymysql
pyvis
# lmdeploy[all]
sqlalchemy
tenacity


# LLM packages
tiktoken
torch
transformers
xxhash
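These dependencies can be installed into the Space environment or a local virtual environment with: pip install -r requirements.txt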