Merge branch 'main' of https://github.com/jin38324/LightRAG
Browse files- README.md +12 -5
- examples/lightrag_api_oracle_demo..py +2 -2
- examples/lightrag_azure_openai_demo.py +23 -38
- lightrag/__init__.py +1 -1
- lightrag/kg/neo4j_impl.py +1 -3
- lightrag/kg/oracle_impl.py +106 -165
- lightrag/lightrag.py +5 -4
- lightrag/llm.py +6 -2
- lightrag/operate.py +27 -19
- lightrag/utils.py +12 -6
- reproduce/Step_1.py +1 -1
README.md
CHANGED
@@ -8,7 +8,7 @@
|
|
8 |
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
9 |
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
|
10 |
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
|
11 |
-
<a href='https://
|
12 |
</p>
|
13 |
<p>
|
14 |
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
|
@@ -16,27 +16,34 @@
|
|
16 |
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
17 |
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
18 |
</p>
|
|
|
|
|
|
|
|
|
19 |
|
20 |
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
21 |

|
22 |
</div>
|
23 |
|
24 |
## 🎉 News
|
25 |
-
- [x] [2024.11.
|
|
|
26 |
- [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
|
27 |
- [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
|
28 |
- [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
|
29 |
- [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
|
30 |
- [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
|
31 |
- [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
32 |
-
- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/
|
33 |
- [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
34 |
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
35 |
|
36 |
## Algorithm Flowchart
|
37 |
|
38 |
-
.
|
25 |

|
26 |
</div>
|
27 |
|
28 |
## 🎉 News
|
29 |
+
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author!
|
30 |
+
- [x] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
|
31 |
- [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
|
32 |
- [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
|
33 |
- [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
|
34 |
- [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
|
35 |
- [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
|
36 |
- [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
37 |
+
- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/yF2MmDJyGJ)! Welcome to join for sharing and discussions! 🎉🎉
|
38 |
- [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
39 |
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
40 |
|
41 |
## Algorithm Flowchart
|
42 |
|
43 |
+

|
44 |
+
*Figure 1: LightRAG Indexing Flowchart*
|
45 |
+

|
46 |
+
*Figure 2: LightRAG Retrieval and Querying Flowchart*
|
47 |
|
48 |
## Install
|
49 |
|
examples/lightrag_api_oracle_demo..py
CHANGED
@@ -162,12 +162,12 @@ class Response(BaseModel):
|
|
162 |
|
163 |
# API routes
|
164 |
|
165 |
-
rag = None
|
166 |
|
167 |
@asynccontextmanager
|
168 |
async def lifespan(app: FastAPI):
|
169 |
global rag
|
170 |
-
rag = await init()
|
171 |
print("done!")
|
172 |
yield
|
173 |
|
|
|
162 |
|
163 |
# API routes
|
164 |
|
165 |
+
rag = None
|
166 |
|
167 |
@asynccontextmanager
|
168 |
async def lifespan(app: FastAPI):
|
169 |
global rag
|
170 |
+
rag = await init()
|
171 |
print("done!")
|
172 |
yield
|
173 |
|
examples/lightrag_azure_openai_demo.py
CHANGED
@@ -4,8 +4,8 @@ from lightrag import LightRAG, QueryParam
|
|
4 |
from lightrag.utils import EmbeddingFunc
|
5 |
import numpy as np
|
6 |
from dotenv import load_dotenv
|
7 |
-
import aiohttp
|
8 |
import logging
|
|
|
9 |
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
|
@@ -32,11 +32,11 @@ os.mkdir(WORKING_DIR)
|
|
32 |
async def llm_model_func(
|
33 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
34 |
) -> str:
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
|
41 |
messages = []
|
42 |
if system_prompt:
|
@@ -45,41 +45,26 @@ async def llm_model_func(
|
|
45 |
messages.extend(history_messages)
|
46 |
messages.append({"role": "user", "content": prompt})
|
47 |
|
48 |
-
|
49 |
-
"
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
async with session.post(endpoint, headers=headers, json=payload) as response:
|
57 |
-
if response.status != 200:
|
58 |
-
raise ValueError(
|
59 |
-
f"Request failed with status {response.status}: {await response.text()}"
|
60 |
-
)
|
61 |
-
result = await response.json()
|
62 |
-
return result["choices"][0]["message"]["content"]
|
63 |
|
64 |
|
65 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
async with session.post(endpoint, headers=headers, json=payload) as response:
|
76 |
-
if response.status != 200:
|
77 |
-
raise ValueError(
|
78 |
-
f"Request failed with status {response.status}: {await response.text()}"
|
79 |
-
)
|
80 |
-
result = await response.json()
|
81 |
-
embeddings = [item["embedding"] for item in result["data"]]
|
82 |
-
return np.array(embeddings)
|
83 |
|
84 |
|
85 |
async def test_funcs():
|
|
|
4 |
from lightrag.utils import EmbeddingFunc
|
5 |
import numpy as np
|
6 |
from dotenv import load_dotenv
|
|
|
7 |
import logging
|
8 |
+
from openai import AzureOpenAI
|
9 |
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
|
|
|
32 |
async def llm_model_func(
|
33 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
34 |
) -> str:
|
35 |
+
client = AzureOpenAI(
|
36 |
+
api_key=AZURE_OPENAI_API_KEY,
|
37 |
+
api_version=AZURE_OPENAI_API_VERSION,
|
38 |
+
azure_endpoint=AZURE_OPENAI_ENDPOINT,
|
39 |
+
)
|
40 |
|
41 |
messages = []
|
42 |
if system_prompt:
|
|
|
45 |
messages.extend(history_messages)
|
46 |
messages.append({"role": "user", "content": prompt})
|
47 |
|
48 |
+
chat_completion = client.chat.completions.create(
|
49 |
+
model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name".
|
50 |
+
messages=messages,
|
51 |
+
temperature=kwargs.get("temperature", 0),
|
52 |
+
top_p=kwargs.get("top_p", 1),
|
53 |
+
n=kwargs.get("n", 1),
|
54 |
+
)
|
55 |
+
return chat_completion.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
59 |
+
client = AzureOpenAI(
|
60 |
+
api_key=AZURE_OPENAI_API_KEY,
|
61 |
+
api_version=AZURE_EMBEDDING_API_VERSION,
|
62 |
+
azure_endpoint=AZURE_OPENAI_ENDPOINT,
|
63 |
+
)
|
64 |
+
embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
|
65 |
+
|
66 |
+
embeddings = [item.embedding for item in embedding.data]
|
67 |
+
return np.array(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
|
70 |
async def test_funcs():
|
lightrag/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
2 |
|
3 |
-
__version__ = "1.0.
|
4 |
__author__ = "Zirui Guo"
|
5 |
__url__ = "https://github.com/HKUDS/LightRAG"
|
|
|
1 |
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
2 |
|
3 |
+
__version__ = "1.0.1"
|
4 |
__author__ = "Zirui Guo"
|
5 |
__url__ = "https://github.com/HKUDS/LightRAG"
|
lightrag/kg/neo4j_impl.py
CHANGED
@@ -86,9 +86,6 @@ class Neo4JStorage(BaseGraphStorage):
|
|
86 |
)
|
87 |
return single_result["edgeExists"]
|
88 |
|
89 |
-
def close(self):
|
90 |
-
self._driver.close()
|
91 |
-
|
92 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
93 |
async with self._driver.session() as session:
|
94 |
entity_name_label = node_id.strip('"')
|
@@ -214,6 +211,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
214 |
neo4jExceptions.ServiceUnavailable,
|
215 |
neo4jExceptions.TransientError,
|
216 |
neo4jExceptions.WriteServiceUnavailable,
|
|
|
217 |
)
|
218 |
),
|
219 |
)
|
|
|
86 |
)
|
87 |
return single_result["edgeExists"]
|
88 |
|
|
|
|
|
|
|
89 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
90 |
async with self._driver.session() as session:
|
91 |
entity_name_label = node_id.strip('"')
|
|
|
211 |
neo4jExceptions.ServiceUnavailable,
|
212 |
neo4jExceptions.TransientError,
|
213 |
neo4jExceptions.WriteServiceUnavailable,
|
214 |
+
neo4jExceptions.ClientError,
|
215 |
)
|
216 |
),
|
217 |
)
|
lightrag/kg/oracle_impl.py
CHANGED
@@ -114,7 +114,7 @@ class OracleDB:
|
|
114 |
|
115 |
logger.info("Finished check all tables in Oracle database")
|
116 |
|
117 |
-
async def query(self, sql: str,
|
118 |
async with self.pool.acquire() as connection:
|
119 |
connection.inputtypehandler = self.input_type_handler
|
120 |
connection.outputtypehandler = self.output_type_handler
|
@@ -173,10 +173,11 @@ class OracleKVStorage(BaseKVStorage):
|
|
173 |
|
174 |
async def get_by_id(self, id: str) -> Union[dict, None]:
|
175 |
"""根据 id 获取 doc_full 数据."""
|
176 |
-
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
177 |
-
|
|
|
178 |
# print("get_by_id:"+SQL)
|
179 |
-
res = await self.db.query(SQL
|
180 |
if res:
|
181 |
data = res # {"data":res}
|
182 |
# print (data)
|
@@ -187,11 +188,11 @@ class OracleKVStorage(BaseKVStorage):
|
|
187 |
# Query by id
|
188 |
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
189 |
"""根据 id 获取 doc_chunks 数据"""
|
190 |
-
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
191 |
-
|
192 |
-
|
193 |
-
#print(
|
194 |
-
res = await self.db.query(SQL,
|
195 |
if res:
|
196 |
data = res # [{"data":i} for i in res]
|
197 |
# print(data)
|
@@ -201,16 +202,12 @@ class OracleKVStorage(BaseKVStorage):
|
|
201 |
|
202 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
203 |
"""过滤掉重复内容"""
|
204 |
-
SQL = SQL_TEMPLATES["filter_keys"].format(
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
logger.error(f"Oracle database error: {e}")
|
211 |
-
print(SQL)
|
212 |
-
print(params)
|
213 |
-
res = await self.db.query(SQL, params,multirows=True)
|
214 |
data = None
|
215 |
if res:
|
216 |
exist_keys = [key["id"] for key in res]
|
@@ -247,29 +244,27 @@ class OracleKVStorage(BaseKVStorage):
|
|
247 |
d["__vector__"] = embeddings[i]
|
248 |
# print(list_data)
|
249 |
for item in list_data:
|
250 |
-
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
260 |
# print(merge_sql)
|
261 |
await self.db.execute(merge_sql, data)
|
262 |
|
263 |
if self.namespace == "full_docs":
|
264 |
for k, v in self._data.items():
|
265 |
# values.clear()
|
266 |
-
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
"content":v["content"],
|
271 |
-
"workspace":self.db.workspace
|
272 |
-
}
|
273 |
# print(merge_sql)
|
274 |
await self.db.execute(merge_sql, data)
|
275 |
return left_data
|
@@ -301,17 +296,18 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
|
301 |
# 转换精度
|
302 |
dtype = str(embedding.dtype).upper()
|
303 |
dimension = embedding.shape[0]
|
304 |
-
embedding_string = "
|
305 |
-
|
306 |
-
SQL = SQL_TEMPLATES[self.namespace].format(
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
313 |
# print(SQL)
|
314 |
-
results = await self.db.query(SQL,
|
315 |
# print("vector search result:",results)
|
316 |
return results
|
317 |
|
@@ -346,18 +342,22 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
346 |
)
|
347 |
embeddings = np.concatenate(embeddings_list)
|
348 |
content_vector = embeddings[0]
|
349 |
-
merge_sql = SQL_TEMPLATES["merge_node"]
|
350 |
-
|
351 |
-
|
352 |
-
"name":entity_name,
|
353 |
-
"entity_type":entity_type,
|
354 |
-
"description":description,
|
355 |
-
"source_chunk_id":source_id,
|
356 |
-
"content":content,
|
357 |
-
"content_vector":content_vector
|
358 |
-
}
|
359 |
# print(merge_sql)
|
360 |
-
await self.db.execute(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
# self._graph.add_node(node_id, **node_data)
|
362 |
|
363 |
async def upsert_edge(
|
@@ -371,8 +371,6 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
371 |
keywords = edge_data["keywords"]
|
372 |
description = edge_data["description"]
|
373 |
source_chunk_id = edge_data["source_id"]
|
374 |
-
logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
|
375 |
-
|
376 |
content = keywords + source_name + target_name + description
|
377 |
contents = [content]
|
378 |
batches = [
|
@@ -384,20 +382,27 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
384 |
)
|
385 |
embeddings = np.concatenate(embeddings_list)
|
386 |
content_vector = embeddings[0]
|
387 |
-
merge_sql = SQL_TEMPLATES["merge_edge"]
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
"keywords":keywords,
|
394 |
-
"description":description,
|
395 |
-
"source_chunk_id":source_chunk_id,
|
396 |
-
"content":content,
|
397 |
-
"content_vector":content_vector
|
398 |
-
}
|
399 |
# print(merge_sql)
|
400 |
-
await self.db.execute(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
402 |
|
403 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
@@ -427,14 +432,12 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
427 |
#################### query method #################
|
428 |
async def has_node(self, node_id: str) -> bool:
|
429 |
"""根据节点id检查节点是否存在"""
|
430 |
-
SQL = SQL_TEMPLATES["has_node"]
|
431 |
-
|
432 |
-
|
433 |
-
"node_id":node_id
|
434 |
-
}
|
435 |
# print(SQL)
|
436 |
# print(self.db.workspace, node_id)
|
437 |
-
res = await self.db.query(SQL
|
438 |
if res:
|
439 |
# print("Node exist!",res)
|
440 |
return True
|
@@ -444,14 +447,13 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
444 |
|
445 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
446 |
"""根据源和目标节点id检查边是否存在"""
|
447 |
-
SQL = SQL_TEMPLATES["has_edge"]
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
}
|
453 |
# print(SQL)
|
454 |
-
res = await self.db.query(SQL
|
455 |
if res:
|
456 |
# print("Edge exist!",res)
|
457 |
return True
|
@@ -461,13 +463,11 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
461 |
|
462 |
async def node_degree(self, node_id: str) -> int:
|
463 |
"""根据节点id获取节点的度"""
|
464 |
-
SQL = SQL_TEMPLATES["node_degree"]
|
465 |
-
|
466 |
-
|
467 |
-
"node_id":node_id
|
468 |
-
}
|
469 |
# print(SQL)
|
470 |
-
res = await self.db.query(SQL
|
471 |
if res:
|
472 |
# print("Node degree",res["degree"])
|
473 |
return res["degree"]
|
@@ -483,14 +483,12 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
483 |
|
484 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
485 |
"""根据节点id获取节点数据"""
|
486 |
-
SQL = SQL_TEMPLATES["get_node"]
|
487 |
-
|
488 |
-
|
489 |
-
"node_id":node_id
|
490 |
-
}
|
491 |
# print(self.db.workspace, node_id)
|
492 |
# print(SQL)
|
493 |
-
res = await self.db.query(SQL
|
494 |
if res:
|
495 |
# print("Get node!",self.db.workspace, node_id,res)
|
496 |
return res
|
@@ -502,13 +500,12 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
502 |
self, source_node_id: str, target_node_id: str
|
503 |
) -> Union[dict, None]:
|
504 |
"""根据源和目标节点id获取边"""
|
505 |
-
SQL = SQL_TEMPLATES["get_edge"]
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
res = await self.db.query(SQL,params)
|
512 |
if res:
|
513 |
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
514 |
return res
|
@@ -519,12 +516,10 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
519 |
async def get_node_edges(self, source_node_id: str):
|
520 |
"""根据节点id获取节点的所有边"""
|
521 |
if await self.has_node(source_node_id):
|
522 |
-
SQL = SQL_TEMPLATES["get_node_edges"]
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
}
|
527 |
-
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
528 |
if res:
|
529 |
data = [(i["source_name"], i["target_name"]) for i in res]
|
530 |
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
@@ -532,29 +527,7 @@ class OracleGraphStorage(BaseGraphStorage):
|
|
532 |
else:
|
533 |
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
534 |
return []
|
535 |
-
|
536 |
-
async def get_all_nodes(self, limit: int):
|
537 |
-
"""查询所有节点"""
|
538 |
-
SQL = SQL_TEMPLATES["get_all_nodes"]
|
539 |
-
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
540 |
-
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
541 |
-
if res:
|
542 |
-
return res
|
543 |
|
544 |
-
async def get_all_edges(self, limit: int):
|
545 |
-
"""查询所有边"""
|
546 |
-
SQL = SQL_TEMPLATES["get_all_edges"]
|
547 |
-
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
548 |
-
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
549 |
-
if res:
|
550 |
-
return res
|
551 |
-
|
552 |
-
async def get_statistics(self):
|
553 |
-
SQL = SQL_TEMPLATES["get_statistics"]
|
554 |
-
params = {"workspace":self.db.workspace}
|
555 |
-
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
556 |
-
if res:
|
557 |
-
return res
|
558 |
|
559 |
N_T = {
|
560 |
"full_docs": "LIGHTRAG_DOC_FULL",
|
@@ -726,37 +699,5 @@ SQL_TEMPLATES = {
|
|
726 |
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
|
727 |
WHEN NOT MATCHED THEN
|
728 |
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
729 |
-
values (:
|
730 |
-
"get_all_nodes":"""WITH t0 AS (
|
731 |
-
SELECT name AS id, entity_type AS label, entity_type, description,
|
732 |
-
'["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
|
733 |
-
FROM lightrag_graph_nodes
|
734 |
-
WHERE workspace = :workspace
|
735 |
-
ORDER BY createtime DESC fetch first :limit rows only
|
736 |
-
), t1 AS (
|
737 |
-
SELECT t0.id, source_chunk_id
|
738 |
-
FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
|
739 |
-
), t2 AS (
|
740 |
-
SELECT t1.id, LISTAGG(t2.content, '\n') content
|
741 |
-
FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
|
742 |
-
GROUP BY t1.id
|
743 |
-
)
|
744 |
-
SELECT t0.id, label, entity_type, description, t2.content
|
745 |
-
FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
|
746 |
-
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
747 |
-
t1.weight,t1.DESCRIPTION,t2.content
|
748 |
-
FROM LIGHTRAG_GRAPH_EDGES t1
|
749 |
-
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
750 |
-
WHERE t1.workspace=:workspace
|
751 |
-
order by t1.CREATETIME DESC
|
752 |
-
fetch first :limit rows only""",
|
753 |
-
"get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
|
754 |
-
count(distinct CASE WHEN type='edge' THEN id END) as edges_count
|
755 |
-
FROM (
|
756 |
-
select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
|
757 |
-
MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
|
758 |
-
UNION
|
759 |
-
select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
|
760 |
-
MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
|
761 |
-
)""",
|
762 |
}
|
|
|
114 |
|
115 |
logger.info("Finished check all tables in Oracle database")
|
116 |
|
117 |
+
async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
|
118 |
async with self.pool.acquire() as connection:
|
119 |
connection.inputtypehandler = self.input_type_handler
|
120 |
connection.outputtypehandler = self.output_type_handler
|
|
|
173 |
|
174 |
async def get_by_id(self, id: str) -> Union[dict, None]:
|
175 |
"""根据 id 获取 doc_full 数据."""
|
176 |
+
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
|
177 |
+
workspace=self.db.workspace, id=id
|
178 |
+
)
|
179 |
# print("get_by_id:"+SQL)
|
180 |
+
res = await self.db.query(SQL)
|
181 |
if res:
|
182 |
data = res # {"data":res}
|
183 |
# print (data)
|
|
|
188 |
# Query by id
|
189 |
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
190 |
"""根据 id 获取 doc_chunks 数据"""
|
191 |
+
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
192 |
+
workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
|
193 |
+
)
|
194 |
+
# print("get_by_ids:"+SQL)
|
195 |
+
res = await self.db.query(SQL, multirows=True)
|
196 |
if res:
|
197 |
data = res # [{"data":i} for i in res]
|
198 |
# print(data)
|
|
|
202 |
|
203 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
204 |
"""过滤掉重复内容"""
|
205 |
+
SQL = SQL_TEMPLATES["filter_keys"].format(
|
206 |
+
table_name=N_T[self.namespace],
|
207 |
+
workspace=self.db.workspace,
|
208 |
+
ids=",".join([f"'{k}'" for k in keys]),
|
209 |
+
)
|
210 |
+
res = await self.db.query(SQL, multirows=True)
|
|
|
|
|
|
|
|
|
211 |
data = None
|
212 |
if res:
|
213 |
exist_keys = [key["id"] for key in res]
|
|
|
244 |
d["__vector__"] = embeddings[i]
|
245 |
# print(list_data)
|
246 |
for item in list_data:
|
247 |
+
merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
|
248 |
+
|
249 |
+
values = [
|
250 |
+
item["__id__"],
|
251 |
+
item["content"],
|
252 |
+
self.db.workspace,
|
253 |
+
item["tokens"],
|
254 |
+
item["chunk_order_index"],
|
255 |
+
item["full_doc_id"],
|
256 |
+
item["__vector__"],
|
257 |
+
]
|
258 |
# print(merge_sql)
|
259 |
await self.db.execute(merge_sql, data)
|
260 |
|
261 |
if self.namespace == "full_docs":
|
262 |
for k, v in self._data.items():
|
263 |
# values.clear()
|
264 |
+
merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
|
265 |
+
check_id=k,
|
266 |
+
)
|
267 |
+
values = [k, self._data[k]["content"], self.db.workspace]
|
|
|
|
|
|
|
268 |
# print(merge_sql)
|
269 |
await self.db.execute(merge_sql, data)
|
270 |
return left_data
|
|
|
296 |
# 转换精度
|
297 |
dtype = str(embedding.dtype).upper()
|
298 |
dimension = embedding.shape[0]
|
299 |
+
embedding_string = ", ".join(map(str, embedding.tolist()))
|
300 |
+
|
301 |
+
SQL = SQL_TEMPLATES[self.namespace].format(
|
302 |
+
embedding_string=embedding_string,
|
303 |
+
dimension=dimension,
|
304 |
+
dtype=dtype,
|
305 |
+
workspace=self.db.workspace,
|
306 |
+
top_k=top_k,
|
307 |
+
better_than_threshold=self.cosine_better_than_threshold,
|
308 |
+
)
|
309 |
# print(SQL)
|
310 |
+
results = await self.db.query(SQL, multirows=True)
|
311 |
# print("vector search result:",results)
|
312 |
return results
|
313 |
|
|
|
342 |
)
|
343 |
embeddings = np.concatenate(embeddings_list)
|
344 |
content_vector = embeddings[0]
|
345 |
+
merge_sql = SQL_TEMPLATES["merge_node"].format(
|
346 |
+
workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
|
347 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
# print(merge_sql)
|
349 |
+
await self.db.execute(
|
350 |
+
merge_sql,
|
351 |
+
[
|
352 |
+
self.db.workspace,
|
353 |
+
entity_name,
|
354 |
+
entity_type,
|
355 |
+
description,
|
356 |
+
source_id,
|
357 |
+
content,
|
358 |
+
content_vector,
|
359 |
+
],
|
360 |
+
)
|
361 |
# self._graph.add_node(node_id, **node_data)
|
362 |
|
363 |
async def upsert_edge(
|
|
|
371 |
keywords = edge_data["keywords"]
|
372 |
description = edge_data["description"]
|
373 |
source_chunk_id = edge_data["source_id"]
|
|
|
|
|
374 |
content = keywords + source_name + target_name + description
|
375 |
contents = [content]
|
376 |
batches = [
|
|
|
382 |
)
|
383 |
embeddings = np.concatenate(embeddings_list)
|
384 |
content_vector = embeddings[0]
|
385 |
+
merge_sql = SQL_TEMPLATES["merge_edge"].format(
|
386 |
+
workspace=self.db.workspace,
|
387 |
+
source_name=source_name,
|
388 |
+
target_name=target_name,
|
389 |
+
source_chunk_id=source_chunk_id,
|
390 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
# print(merge_sql)
|
392 |
+
await self.db.execute(
|
393 |
+
merge_sql,
|
394 |
+
[
|
395 |
+
self.db.workspace,
|
396 |
+
source_name,
|
397 |
+
target_name,
|
398 |
+
weight,
|
399 |
+
keywords,
|
400 |
+
description,
|
401 |
+
source_chunk_id,
|
402 |
+
content,
|
403 |
+
content_vector,
|
404 |
+
],
|
405 |
+
)
|
406 |
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
407 |
|
408 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
|
|
432 |
#################### query method #################
|
433 |
async def has_node(self, node_id: str) -> bool:
|
434 |
"""根据节点id检查节点是否存在"""
|
435 |
+
SQL = SQL_TEMPLATES["has_node"].format(
|
436 |
+
workspace=self.db.workspace, node_id=node_id
|
437 |
+
)
|
|
|
|
|
438 |
# print(SQL)
|
439 |
# print(self.db.workspace, node_id)
|
440 |
+
res = await self.db.query(SQL)
|
441 |
if res:
|
442 |
# print("Node exist!",res)
|
443 |
return True
|
|
|
447 |
|
448 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
449 |
"""根据源和目标节点id检查边是否存在"""
|
450 |
+
SQL = SQL_TEMPLATES["has_edge"].format(
|
451 |
+
workspace=self.db.workspace,
|
452 |
+
source_node_id=source_node_id,
|
453 |
+
target_node_id=target_node_id,
|
454 |
+
)
|
|
|
455 |
# print(SQL)
|
456 |
+
res = await self.db.query(SQL)
|
457 |
if res:
|
458 |
# print("Edge exist!",res)
|
459 |
return True
|
|
|
463 |
|
464 |
async def node_degree(self, node_id: str) -> int:
|
465 |
"""根据节点id获取节点的度"""
|
466 |
+
SQL = SQL_TEMPLATES["node_degree"].format(
|
467 |
+
workspace=self.db.workspace, node_id=node_id
|
468 |
+
)
|
|
|
|
|
469 |
# print(SQL)
|
470 |
+
res = await self.db.query(SQL)
|
471 |
if res:
|
472 |
# print("Node degree",res["degree"])
|
473 |
return res["degree"]
|
|
|
483 |
|
484 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
485 |
"""根据节点id获取节点数据"""
|
486 |
+
SQL = SQL_TEMPLATES["get_node"].format(
|
487 |
+
workspace=self.db.workspace, node_id=node_id
|
488 |
+
)
|
|
|
|
|
489 |
# print(self.db.workspace, node_id)
|
490 |
# print(SQL)
|
491 |
+
res = await self.db.query(SQL)
|
492 |
if res:
|
493 |
# print("Get node!",self.db.workspace, node_id,res)
|
494 |
return res
|
|
|
500 |
self, source_node_id: str, target_node_id: str
|
501 |
) -> Union[dict, None]:
|
502 |
"""根据源和目标节点id获取边"""
|
503 |
+
SQL = SQL_TEMPLATES["get_edge"].format(
|
504 |
+
workspace=self.db.workspace,
|
505 |
+
source_node_id=source_node_id,
|
506 |
+
target_node_id=target_node_id,
|
507 |
+
)
|
508 |
+
res = await self.db.query(SQL)
|
|
|
509 |
if res:
|
510 |
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
511 |
return res
|
|
|
516 |
async def get_node_edges(self, source_node_id: str):
|
517 |
"""根据节点id获取节点的所有边"""
|
518 |
if await self.has_node(source_node_id):
|
519 |
+
SQL = SQL_TEMPLATES["get_node_edges"].format(
|
520 |
+
workspace=self.db.workspace, source_node_id=source_node_id
|
521 |
+
)
|
522 |
+
res = await self.db.query(sql=SQL, multirows=True)
|
|
|
|
|
523 |
if res:
|
524 |
data = [(i["source_name"], i["target_name"]) for i in res]
|
525 |
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
|
|
527 |
else:
|
528 |
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
529 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
|
532 |
N_T = {
|
533 |
"full_docs": "LIGHTRAG_DOC_FULL",
|
|
|
699 |
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
|
700 |
WHEN NOT MATCHED THEN
|
701 |
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
702 |
+
values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
703 |
}
|
lightrag/lightrag.py
CHANGED
@@ -172,9 +172,7 @@ class LightRAG:
|
|
172 |
embedding_func=self.embedding_func,
|
173 |
)
|
174 |
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
175 |
-
namespace="chunk_entity_relation",
|
176 |
-
global_config=asdict(self),
|
177 |
-
embedding_func=self.embedding_func,
|
178 |
)
|
179 |
####
|
180 |
# add embedding func by walter over
|
@@ -226,6 +224,7 @@ class LightRAG:
|
|
226 |
return loop.run_until_complete(self.ainsert(string_or_strings))
|
227 |
|
228 |
async def ainsert(self, string_or_strings):
|
|
|
229 |
try:
|
230 |
if isinstance(string_or_strings, str):
|
231 |
string_or_strings = [string_or_strings]
|
@@ -239,6 +238,7 @@ class LightRAG:
|
|
239 |
if not len(new_docs):
|
240 |
logger.warning("All docs are already in the storage")
|
241 |
return
|
|
|
242 |
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
243 |
|
244 |
inserting_chunks = {}
|
@@ -285,7 +285,8 @@ class LightRAG:
|
|
285 |
await self.full_docs.upsert(new_docs)
|
286 |
await self.text_chunks.upsert(inserting_chunks)
|
287 |
finally:
|
288 |
-
|
|
|
289 |
|
290 |
async def _insert_done(self):
|
291 |
tasks = []
|
|
|
172 |
embedding_func=self.embedding_func,
|
173 |
)
|
174 |
self.chunk_entity_relation_graph = self.graph_storage_cls(
|
175 |
+
namespace="chunk_entity_relation", global_config=asdict(self)
|
|
|
|
|
176 |
)
|
177 |
####
|
178 |
# add embedding func by walter over
|
|
|
224 |
return loop.run_until_complete(self.ainsert(string_or_strings))
|
225 |
|
226 |
async def ainsert(self, string_or_strings):
|
227 |
+
update_storage = False
|
228 |
try:
|
229 |
if isinstance(string_or_strings, str):
|
230 |
string_or_strings = [string_or_strings]
|
|
|
238 |
if not len(new_docs):
|
239 |
logger.warning("All docs are already in the storage")
|
240 |
return
|
241 |
+
update_storage = True
|
242 |
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
243 |
|
244 |
inserting_chunks = {}
|
|
|
285 |
await self.full_docs.upsert(new_docs)
|
286 |
await self.text_chunks.upsert(inserting_chunks)
|
287 |
finally:
|
288 |
+
if update_storage:
|
289 |
+
await self._insert_done()
|
290 |
|
291 |
async def _insert_done(self):
|
292 |
tasks = []
|
lightrag/llm.py
CHANGED
@@ -696,13 +696,17 @@ async def bedrock_embedding(
|
|
696 |
|
697 |
|
698 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
|
|
699 |
input_ids = tokenizer(
|
700 |
texts, return_tensors="pt", padding=True, truncation=True
|
701 |
-
).input_ids
|
702 |
with torch.no_grad():
|
703 |
outputs = embed_model(input_ids)
|
704 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
705 |
-
|
|
|
|
|
|
|
706 |
|
707 |
|
708 |
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
|
|
696 |
|
697 |
|
698 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
699 |
+
device = next(embed_model.parameters()).device
|
700 |
input_ids = tokenizer(
|
701 |
texts, return_tensors="pt", padding=True, truncation=True
|
702 |
+
).input_ids.to(device)
|
703 |
with torch.no_grad():
|
704 |
outputs = embed_model(input_ids)
|
705 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
706 |
+
if embeddings.dtype == torch.bfloat16:
|
707 |
+
return embeddings.detach().to(torch.float32).cpu().numpy()
|
708 |
+
else:
|
709 |
+
return embeddings.detach().cpu().numpy()
|
710 |
|
711 |
|
712 |
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
lightrag/operate.py
CHANGED
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
|
|
662 |
all_text_units_lookup = {}
|
663 |
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
664 |
for c_id in this_text_units:
|
665 |
-
if c_id in all_text_units_lookup:
|
666 |
-
|
667 |
-
|
668 |
-
|
|
|
|
|
|
|
|
|
669 |
for e in this_edges:
|
670 |
if (
|
671 |
e[1] in all_one_hop_text_units_lookup
|
672 |
and c_id in all_one_hop_text_units_lookup[e[1]]
|
673 |
):
|
674 |
-
relation_counts += 1
|
675 |
-
|
676 |
-
chunk_data = await text_chunks_db.get_by_id(c_id)
|
677 |
-
if chunk_data is not None and "content" in chunk_data: # Add content check
|
678 |
-
all_text_units_lookup[c_id] = {
|
679 |
-
"data": chunk_data,
|
680 |
-
"order": index,
|
681 |
-
"relation_counts": relation_counts,
|
682 |
-
}
|
683 |
|
684 |
# Filter out None values and ensure data has content
|
685 |
all_text_units = [
|
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
|
|
714 |
all_related_edges = await asyncio.gather(
|
715 |
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
716 |
)
|
717 |
-
all_edges =
|
|
|
|
|
718 |
for this_edges in all_related_edges:
|
719 |
-
|
720 |
-
|
|
|
|
|
|
|
|
|
721 |
all_edges_pack = await asyncio.gather(
|
722 |
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
723 |
)
|
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
|
|
828 |
query_param: QueryParam,
|
829 |
knowledge_graph_inst: BaseGraphStorage,
|
830 |
):
|
831 |
-
entity_names =
|
|
|
|
|
832 |
for e in edge_datas:
|
833 |
-
|
834 |
-
|
|
|
|
|
|
|
|
|
835 |
|
836 |
node_datas = await asyncio.gather(
|
837 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
|
|
662 |
all_text_units_lookup = {}
|
663 |
for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
|
664 |
for c_id in this_text_units:
|
665 |
+
if c_id not in all_text_units_lookup:
|
666 |
+
all_text_units_lookup[c_id] = {
|
667 |
+
"data": await text_chunks_db.get_by_id(c_id),
|
668 |
+
"order": index,
|
669 |
+
"relation_counts": 0,
|
670 |
+
}
|
671 |
+
|
672 |
+
if this_edges:
|
673 |
for e in this_edges:
|
674 |
if (
|
675 |
e[1] in all_one_hop_text_units_lookup
|
676 |
and c_id in all_one_hop_text_units_lookup[e[1]]
|
677 |
):
|
678 |
+
all_text_units_lookup[c_id]["relation_counts"] += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
679 |
|
680 |
# Filter out None values and ensure data has content
|
681 |
all_text_units = [
|
|
|
710 |
all_related_edges = await asyncio.gather(
|
711 |
*[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
|
712 |
)
|
713 |
+
all_edges = []
|
714 |
+
seen = set()
|
715 |
+
|
716 |
for this_edges in all_related_edges:
|
717 |
+
for e in this_edges:
|
718 |
+
sorted_edge = tuple(sorted(e))
|
719 |
+
if sorted_edge not in seen:
|
720 |
+
seen.add(sorted_edge)
|
721 |
+
all_edges.append(sorted_edge)
|
722 |
+
|
723 |
all_edges_pack = await asyncio.gather(
|
724 |
*[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
|
725 |
)
|
|
|
830 |
query_param: QueryParam,
|
831 |
knowledge_graph_inst: BaseGraphStorage,
|
832 |
):
|
833 |
+
entity_names = []
|
834 |
+
seen = set()
|
835 |
+
|
836 |
for e in edge_datas:
|
837 |
+
if e["src_id"] not in seen:
|
838 |
+
entity_names.append(e["src_id"])
|
839 |
+
seen.add(e["src_id"])
|
840 |
+
if e["tgt_id"] not in seen:
|
841 |
+
entity_names.append(e["tgt_id"])
|
842 |
+
seen.add(e["tgt_id"])
|
843 |
|
844 |
node_datas = await asyncio.gather(
|
845 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
lightrag/utils.py
CHANGED
@@ -290,13 +290,19 @@ def process_combine_contexts(hl, ll):
|
|
290 |
if list_ll:
|
291 |
list_ll = [",".join(item[1:]) for item in list_ll if item]
|
292 |
|
293 |
-
|
|
|
294 |
|
295 |
-
|
|
|
|
|
|
|
296 |
|
297 |
-
|
298 |
-
combined_sources.append(f"{i},\t{item}")
|
299 |
|
300 |
-
|
|
|
301 |
|
302 |
-
|
|
|
|
|
|
290 |
if list_ll:
|
291 |
list_ll = [",".join(item[1:]) for item in list_ll if item]
|
292 |
|
293 |
+
combined_sources = []
|
294 |
+
seen = set()
|
295 |
|
296 |
+
for item in list_hl + list_ll:
|
297 |
+
if item and item not in seen:
|
298 |
+
combined_sources.append(item)
|
299 |
+
seen.add(item)
|
300 |
|
301 |
+
combined_sources_result = [",\t".join(header)]
|
|
|
302 |
|
303 |
+
for i, item in enumerate(combined_sources, start=1):
|
304 |
+
combined_sources_result.append(f"{i},\t{item}")
|
305 |
|
306 |
+
combined_sources_result = "\n".join(combined_sources_result)
|
307 |
+
|
308 |
+
return combined_sources_result
|
reproduce/Step_1.py
CHANGED
@@ -24,7 +24,7 @@ def insert_text(rag, file_path):
|
|
24 |
|
25 |
|
26 |
cls = "agriculture"
|
27 |
-
WORKING_DIR = "../{cls}"
|
28 |
|
29 |
if not os.path.exists(WORKING_DIR):
|
30 |
os.mkdir(WORKING_DIR)
|
|
|
24 |
|
25 |
|
26 |
cls = "agriculture"
|
27 |
+
WORKING_DIR = f"../{cls}"
|
28 |
|
29 |
if not os.path.exists(WORKING_DIR):
|
30 |
os.mkdir(WORKING_DIR)
|