Commit ac47ddf by jin
Parents: c5aaad2 3b1c501

Merge branch 'main' of https://github.com/jin38324/LightRAG

README.md CHANGED
@@ -8,7 +8,7 @@
  <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
  <a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
  <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
- <a href='https://discord.gg/yF2MmDJyGJ'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
+ <a href='https://learnopencv.com/lightrag'><img src='https://img.shields.io/badge/LearnOpenCV-blue'></a>
  </p>
  <p>
  <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -16,27 +16,34 @@
  <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
  <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
  </p>
+ <p>
+ <a href='https://discord.gg/yF2MmDJyGJ'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
+ <a href='https://github.com/HKUDS/LightRAG/issues/285'><img src='https://img.shields.io/badge/群聊-wechat-green'></a>
+ </p>

  This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
  ![LightRAG Diagram](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
  </div>

  ## 🎉 News
- - [x] [2024.11.12]🎯📢You can [use Oracle Database 23ai for all storage types (kv/vector/graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py) now.
+ - [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author!
+ - [x] [2024.11.12]🎯📢LightRAG now supports [Oracle Database 23ai for all storage types (KV, vector, and graph)](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_oracle_demo.py).
  - [x] [2024.11.11]🎯📢LightRAG now supports [deleting entities by their names](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete-entity).
  - [x] [2024.11.09]🎯📢Introducing the [LightRAG Gui](https://lightrag-gui.streamlit.app), which allows you to insert, query, visualize, and download LightRAG knowledge.
  - [x] [2024.11.04]🎯📢You can now [use Neo4J for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-neo4j-for-storage).
  - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
  - [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
  - [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- - [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
+ - [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/yF2MmDJyGJ)! Welcome to join for sharing and discussions! 🎉🎉
  - [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
  - [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!

  ## Algorithm Flowchart

- ![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
-
+ ![LightRAG Indexing Flowchart](https://learnopencv.com/wp-content/uploads/2024/11/LightRAG-VectorDB-Json-KV-Store-Indexing-Flowchart-scaled.jpg)
+ *Figure 1: LightRAG Indexing Flowchart*
+ ![LightRAG Retrieval and Querying Flowchart](https://learnopencv.com/wp-content/uploads/2024/11/LightRAG-Querying-Flowchart-Dual-Level-Retrieval-Generation-Knowledge-Graphs-scaled.jpg)
+ *Figure 2: LightRAG Retrieval and Querying Flowchart*

  ## Install
examples/lightrag_api_oracle_demo..py CHANGED
@@ -162,12 +162,12 @@ class Response(BaseModel):

  # API routes

- rag = None  # defined as a global object
+ rag = None

  @asynccontextmanager
  async def lifespan(app: FastAPI):
      global rag
-     rag = await init()  # initialize `rag` at application startup
+     rag = await init()
      print("done!")
      yield
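The change above strips the Chinese comments but keeps the FastAPI `lifespan` pattern: a module-level `rag` handle that is populated once, before the app starts serving. A minimal self-contained sketch of that pattern, with a placeholder `init()` standing in for the demo's Oracle-backed LightRAG setup:

```python
# Sketch of the lifespan startup pattern used in the demo above.
# `init()` here is a placeholder; the real demo builds a LightRAG
# instance wired to Oracle storage.
from contextlib import asynccontextmanager

from fastapi import FastAPI

rag = None  # global handle, filled in at startup


async def init():
    return object()  # placeholder for the async LightRAG initializer


@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag
    rag = await init()  # runs once, before the first request is served
    print("done!")
    yield  # the application serves requests while suspended here
    # anything after `yield` would run at shutdown


app = FastAPI(lifespan=lifespan)
```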
examples/lightrag_azure_openai_demo.py CHANGED
@@ -4,8 +4,8 @@ from lightrag import LightRAG, QueryParam
  from lightrag.utils import EmbeddingFunc
  import numpy as np
  from dotenv import load_dotenv
- import aiohttp
  import logging
+ from openai import AzureOpenAI

  logging.basicConfig(level=logging.INFO)

@@ -32,11 +32,11 @@ os.mkdir(WORKING_DIR)
  async def llm_model_func(
      prompt, system_prompt=None, history_messages=[], **kwargs
  ) -> str:
-     headers = {
-         "Content-Type": "application/json",
-         "api-key": AZURE_OPENAI_API_KEY,
-     }
-     endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
+     client = AzureOpenAI(
+         api_key=AZURE_OPENAI_API_KEY,
+         api_version=AZURE_OPENAI_API_VERSION,
+         azure_endpoint=AZURE_OPENAI_ENDPOINT,
+     )

      messages = []
      if system_prompt:
@@ -45,41 +45,26 @@ async def llm_model_func(
      messages.extend(history_messages)
      messages.append({"role": "user", "content": prompt})

-     payload = {
-         "messages": messages,
-         "temperature": kwargs.get("temperature", 0),
-         "top_p": kwargs.get("top_p", 1),
-         "n": kwargs.get("n", 1),
-     }
-
-     async with aiohttp.ClientSession() as session:
-         async with session.post(endpoint, headers=headers, json=payload) as response:
-             if response.status != 200:
-                 raise ValueError(
-                     f"Request failed with status {response.status}: {await response.text()}"
-                 )
-             result = await response.json()
-             return result["choices"][0]["message"]["content"]
+     chat_completion = client.chat.completions.create(
+         model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name"
+         messages=messages,
+         temperature=kwargs.get("temperature", 0),
+         top_p=kwargs.get("top_p", 1),
+         n=kwargs.get("n", 1),
+     )
+     return chat_completion.choices[0].message.content


  async def embedding_func(texts: list[str]) -> np.ndarray:
-     headers = {
-         "Content-Type": "application/json",
-         "api-key": AZURE_OPENAI_API_KEY,
-     }
-     endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
-
-     payload = {"input": texts}
-
-     async with aiohttp.ClientSession() as session:
-         async with session.post(endpoint, headers=headers, json=payload) as response:
-             if response.status != 200:
-                 raise ValueError(
-                     f"Request failed with status {response.status}: {await response.text()}"
-                 )
-             result = await response.json()
-             embeddings = [item["embedding"] for item in result["data"]]
-             return np.array(embeddings)
+     client = AzureOpenAI(
+         api_key=AZURE_OPENAI_API_KEY,
+         api_version=AZURE_EMBEDDING_API_VERSION,
+         azure_endpoint=AZURE_OPENAI_ENDPOINT,
+     )
+     embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
+
+     embeddings = [item.embedding for item in embedding.data]
+     return np.array(embeddings)


  async def test_funcs():
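The demo now delegates request construction, authentication headers, and error handling to the official `openai` SDK instead of hand-built `aiohttp` calls. A minimal sketch of the same client usage outside the demo (the environment-variable names mirror the demo; the values are placeholders):

```python
# Sketch of the AzureOpenAI client calls the demo switched to.
import os

import numpy as np
from openai import AzureOpenAI

client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
)

# Chat: `model` takes the Azure *deployment* name, not a raw model id.
chat = client.chat.completions.create(
    model=os.environ["AZURE_OPENAI_DEPLOYMENT"],
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0,
)
print(chat.choices[0].message.content)

# Embeddings: a single call can embed a batch of texts.
emb = client.embeddings.create(
    model=os.environ["AZURE_EMBEDDING_DEPLOYMENT"],
    input=["hello", "world"],
)
print(np.array([item.embedding for item in emb.data]).shape)
```

Note that `AzureOpenAI` is the synchronous client, so these calls block the event loop when made inside `async def` functions; the SDK's `AsyncAzureOpenAI` is the non-blocking variant.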
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
  from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

- __version__ = "1.0.0"
+ __version__ = "1.0.1"
  __author__ = "Zirui Guo"
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/kg/neo4j_impl.py CHANGED
@@ -86,9 +86,6 @@ class Neo4JStorage(BaseGraphStorage):
          )
          return single_result["edgeExists"]

-     def close(self):
-         self._driver.close()
-
      async def get_node(self, node_id: str) -> Union[dict, None]:
          async with self._driver.session() as session:
              entity_name_label = node_id.strip('"')
@@ -214,6 +211,7 @@ class Neo4JStorage(BaseGraphStorage):
                  neo4jExceptions.ServiceUnavailable,
                  neo4jExceptions.TransientError,
                  neo4jExceptions.WriteServiceUnavailable,
+                 neo4jExceptions.ClientError,
              )
          ),
      )
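The first hunk removes the synchronous `close()` helper; the second widens the retry policy so `ClientError` is also retried. The surrounding code is a tenacity-style decorator keyed on exception types; a sketch of its shape (the stop/wait settings and the decorated function here are illustrative, not taken from the file):

```python
# Sketch of a tenacity retry decorator that treats the listed Neo4j
# exception types, now including ClientError, as retryable.
from neo4j import exceptions as neo4jExceptions
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,  # newly retryable in this commit
        )
    ),
)
async def upsert_node(node_id: str, node_data: dict) -> None:
    ...  # a decorated storage write; any listed exception triggers a retry
```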
lightrag/kg/oracle_impl.py CHANGED
@@ -114,7 +114,7 @@ class OracleDB:

          logger.info("Finished check all tables in Oracle database")

-     async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
+     async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
          async with self.pool.acquire() as connection:
              connection.inputtypehandler = self.input_type_handler
              connection.outputtypehandler = self.output_type_handler
@@ -173,10 +173,11 @@ class OracleKVStorage(BaseKVStorage):

      async def get_by_id(self, id: str) -> Union[dict, None]:
          """Fetch doc_full data by id."""
-         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
-         params = {"workspace": self.db.workspace, "id": id}
+         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
+             workspace=self.db.workspace, id=id
+         )
          # print("get_by_id:"+SQL)
-         res = await self.db.query(SQL, params)
+         res = await self.db.query(SQL)
          if res:
              data = res  # {"data":res}
              # print (data)
@@ -187,11 +188,11 @@ class OracleKVStorage(BaseKVStorage):
      # Query by id
      async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
          """Fetch doc_chunks data by id."""
-         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
-         params = {"workspace": self.db.workspace}
-         # print("get_by_ids:"+SQL)
-         # print(params)
-         res = await self.db.query(SQL, params, multirows=True)
+         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
+             workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
+         )
+         # print("get_by_ids:"+SQL)
+         res = await self.db.query(SQL, multirows=True)
          if res:
              data = res  # [{"data":i} for i in res]
              # print(data)
@@ -201,16 +202,12 @@ class OracleKVStorage(BaseKVStorage):

      async def filter_keys(self, keys: list[str]) -> set[str]:
          """Filter out duplicate content."""
-         SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
-             ids=",".join([f"'{id}'" for id in keys]))
-         params = {"workspace": self.db.workspace}
-         try:
-             await self.db.query(SQL, params)
-         except Exception as e:
-             logger.error(f"Oracle database error: {e}")
-             print(SQL)
-             print(params)
-         res = await self.db.query(SQL, params, multirows=True)
+         SQL = SQL_TEMPLATES["filter_keys"].format(
+             table_name=N_T[self.namespace],
+             workspace=self.db.workspace,
+             ids=",".join([f"'{k}'" for k in keys]),
+         )
+         res = await self.db.query(SQL, multirows=True)
          data = None
          if res:
              exist_keys = [key["id"] for key in res]
@@ -247,29 +244,27 @@ class OracleKVStorage(BaseKVStorage):
              d["__vector__"] = embeddings[i]
          # print(list_data)
          for item in list_data:
-             merge_sql = SQL_TEMPLATES["merge_chunk"]
-             data = {
-                 "check_id": item["__id__"],
-                 "id": item["__id__"],
-                 "content": item["content"],
-                 "workspace": self.db.workspace,
-                 "tokens": item["tokens"],
-                 "chunk_order_index": item["chunk_order_index"],
-                 "full_doc_id": item["full_doc_id"],
-                 "content_vector": item["__vector__"]
-             }
+             merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
+
+             values = [
+                 item["__id__"],
+                 item["content"],
+                 self.db.workspace,
+                 item["tokens"],
+                 item["chunk_order_index"],
+                 item["full_doc_id"],
+                 item["__vector__"],
+             ]
              # print(merge_sql)
              await self.db.execute(merge_sql, data)

          if self.namespace == "full_docs":
              for k, v in self._data.items():
                  # values.clear()
-                 merge_sql = SQL_TEMPLATES["merge_doc_full"]
-                 data = {
-                     "check_id": k,
-                     "id": k,
-                     "content": v["content"],
-                     "workspace": self.db.workspace
-                 }
+                 merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
+                     check_id=k,
+                 )
+                 values = [k, self._data[k]["content"], self.db.workspace]
                  # print(merge_sql)
                  await self.db.execute(merge_sql, data)
          return left_data
@@ -301,17 +296,18 @@ class OracleVectorDBStorage(BaseVectorStorage):
          # Convert precision
          dtype = str(embedding.dtype).upper()
          dimension = embedding.shape[0]
-         embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
-
-         SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
-         params = {
-             "embedding_string": embedding_string,
-             "workspace": self.db.workspace,
-             "top_k": top_k,
-             "better_than_threshold": self.cosine_better_than_threshold,
-         }
+         embedding_string = ", ".join(map(str, embedding.tolist()))
+
+         SQL = SQL_TEMPLATES[self.namespace].format(
+             embedding_string=embedding_string,
+             dimension=dimension,
+             dtype=dtype,
+             workspace=self.db.workspace,
+             top_k=top_k,
+             better_than_threshold=self.cosine_better_than_threshold,
+         )
          # print(SQL)
-         results = await self.db.query(SQL, params=params, multirows=True)
+         results = await self.db.query(SQL, multirows=True)
          # print("vector search result:",results)
          return results

@@ -346,18 +342,22 @@ class OracleGraphStorage(BaseGraphStorage):
          )
          embeddings = np.concatenate(embeddings_list)
          content_vector = embeddings[0]
-         merge_sql = SQL_TEMPLATES["merge_node"]
-         data = {
-             "workspace": self.db.workspace,
-             "name": entity_name,
-             "entity_type": entity_type,
-             "description": description,
-             "source_chunk_id": source_id,
-             "content": content,
-             "content_vector": content_vector
-         }
+         merge_sql = SQL_TEMPLATES["merge_node"].format(
+             workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
+         )
          # print(merge_sql)
-         await self.db.execute(merge_sql, data)
+         await self.db.execute(
+             merge_sql,
+             [
+                 self.db.workspace,
+                 entity_name,
+                 entity_type,
+                 description,
+                 source_id,
+                 content,
+                 content_vector,
+             ],
+         )
          # self._graph.add_node(node_id, **node_data)

      async def upsert_edge(
@@ -371,8 +371,6 @@ class OracleGraphStorage(BaseGraphStorage):
          keywords = edge_data["keywords"]
          description = edge_data["description"]
          source_chunk_id = edge_data["source_id"]
-         logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
-
          content = keywords + source_name + target_name + description
          contents = [content]
          batches = [
@@ -384,20 +382,27 @@
          )
          embeddings = np.concatenate(embeddings_list)
          content_vector = embeddings[0]
-         merge_sql = SQL_TEMPLATES["merge_edge"]
-         data = {
-             "workspace": self.db.workspace,
-             "source_name": source_name,
-             "target_name": target_name,
-             "weight": weight,
-             "keywords": keywords,
-             "description": description,
-             "source_chunk_id": source_chunk_id,
-             "content": content,
-             "content_vector": content_vector
-         }
+         merge_sql = SQL_TEMPLATES["merge_edge"].format(
+             workspace=self.db.workspace,
+             source_name=source_name,
+             target_name=target_name,
+             source_chunk_id=source_chunk_id,
+         )
          # print(merge_sql)
-         await self.db.execute(merge_sql, data)
+         await self.db.execute(
+             merge_sql,
+             [
+                 self.db.workspace,
+                 source_name,
+                 target_name,
+                 weight,
+                 keywords,
+                 description,
+                 source_chunk_id,
+                 content,
+                 content_vector,
+             ],
+         )
          # self._graph.add_edge(source_node_id, target_node_id, **edge_data)

      async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -427,14 +432,12 @@ class OracleGraphStorage(BaseGraphStorage):
      #################### query method #################
      async def has_node(self, node_id: str) -> bool:
          """Check whether a node exists by node id."""
-         SQL = SQL_TEMPLATES["has_node"]
-         params = {
-             "workspace": self.db.workspace,
-             "node_id": node_id
-         }
+         SQL = SQL_TEMPLATES["has_node"].format(
+             workspace=self.db.workspace, node_id=node_id
+         )
          # print(SQL)
          # print(self.db.workspace, node_id)
-         res = await self.db.query(SQL, params)
+         res = await self.db.query(SQL)
          if res:
              # print("Node exist!",res)
              return True
@@ -444,14 +447,13 @@ class OracleGraphStorage(BaseGraphStorage):

      async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
          """Check whether an edge exists by source and target node ids."""
-         SQL = SQL_TEMPLATES["has_edge"]
-         params = {
-             "workspace": self.db.workspace,
-             "source_node_id": source_node_id,
-             "target_node_id": target_node_id
-         }
+         SQL = SQL_TEMPLATES["has_edge"].format(
+             workspace=self.db.workspace,
+             source_node_id=source_node_id,
+             target_node_id=target_node_id,
+         )
          # print(SQL)
-         res = await self.db.query(SQL, params)
+         res = await self.db.query(SQL)
          if res:
              # print("Edge exist!",res)
              return True
@@ -461,13 +463,11 @@ class OracleGraphStorage(BaseGraphStorage):

      async def node_degree(self, node_id: str) -> int:
          """Get a node's degree by node id."""
-         SQL = SQL_TEMPLATES["node_degree"]
-         params = {
-             "workspace": self.db.workspace,
-             "node_id": node_id
-         }
+         SQL = SQL_TEMPLATES["node_degree"].format(
+             workspace=self.db.workspace, node_id=node_id
+         )
          # print(SQL)
-         res = await self.db.query(SQL, params)
+         res = await self.db.query(SQL)
          if res:
              # print("Node degree",res["degree"])
              return res["degree"]
@@ -483,14 +483,12 @@ class OracleGraphStorage(BaseGraphStorage):

      async def get_node(self, node_id: str) -> Union[dict, None]:
          """Get node data by node id."""
-         SQL = SQL_TEMPLATES["get_node"]
-         params = {
-             "workspace": self.db.workspace,
-             "node_id": node_id
-         }
+         SQL = SQL_TEMPLATES["get_node"].format(
+             workspace=self.db.workspace, node_id=node_id
+         )
          # print(self.db.workspace, node_id)
          # print(SQL)
-         res = await self.db.query(SQL, params)
+         res = await self.db.query(SQL)
          if res:
              # print("Get node!",self.db.workspace, node_id,res)
              return res
@@ -502,13 +500,12 @@ class OracleGraphStorage(BaseGraphStorage):
          self, source_node_id: str, target_node_id: str
      ) -> Union[dict, None]:
          """Get an edge by source and target node ids."""
-         SQL = SQL_TEMPLATES["get_edge"]
-         params = {
-             "workspace": self.db.workspace,
-             "source_node_id": source_node_id,
-             "target_node_id": target_node_id
-         }
-         res = await self.db.query(SQL, params)
+         SQL = SQL_TEMPLATES["get_edge"].format(
+             workspace=self.db.workspace,
+             source_node_id=source_node_id,
+             target_node_id=target_node_id,
+         )
+         res = await self.db.query(SQL)
          if res:
              # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
              return res
@@ -519,12 +516,10 @@ class OracleGraphStorage(BaseGraphStorage):
      async def get_node_edges(self, source_node_id: str):
          """Get all edges of a node by node id."""
          if await self.has_node(source_node_id):
-             SQL = SQL_TEMPLATES["get_node_edges"]
-             params = {
-                 "workspace": self.db.workspace,
-                 "source_node_id": source_node_id
-             }
-             res = await self.db.query(sql=SQL, params=params, multirows=True)
+             SQL = SQL_TEMPLATES["get_node_edges"].format(
+                 workspace=self.db.workspace, source_node_id=source_node_id
+             )
+             res = await self.db.query(sql=SQL, multirows=True)
              if res:
                  data = [(i["source_name"], i["target_name"]) for i in res]
                  # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -532,29 +527,7 @@ class OracleGraphStorage(BaseGraphStorage):
          else:
              # print("Node Edge not exist!",self.db.workspace, source_node_id)
              return []
-
-     async def get_all_nodes(self, limit: int):
-         """Query all nodes."""
-         SQL = SQL_TEMPLATES["get_all_nodes"]
-         params = {"workspace": self.db.workspace, "limit": str(limit)}
-         res = await self.db.query(sql=SQL, params=params, multirows=True)
-         if res:
-             return res

-     async def get_all_edges(self, limit: int):
-         """Query all edges."""
-         SQL = SQL_TEMPLATES["get_all_edges"]
-         params = {"workspace": self.db.workspace, "limit": str(limit)}
-         res = await self.db.query(sql=SQL, params=params, multirows=True)
-         if res:
-             return res
-
-     async def get_statistics(self):
-         SQL = SQL_TEMPLATES["get_statistics"]
-         params = {"workspace": self.db.workspace}
-         res = await self.db.query(sql=SQL, params=params, multirows=True)
-         if res:
-             return res

  N_T = {
      "full_docs": "LIGHTRAG_DOC_FULL",
@@ -726,37 +699,5 @@ SQL_TEMPLATES = {
      ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
      WHEN NOT MATCHED THEN
      INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-     values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+     values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
-     "get_all_nodes": """WITH t0 AS (
-         SELECT name AS id, entity_type AS label, entity_type, description,
-             '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
-         FROM lightrag_graph_nodes
-         WHERE workspace = :workspace
-         ORDER BY createtime DESC fetch first :limit rows only
-     ), t1 AS (
-         SELECT t0.id, source_chunk_id
-         FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
-     ), t2 AS (
-         SELECT t1.id, LISTAGG(t2.content, '\n') content
-         FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
-         GROUP BY t1.id
-     )
-     SELECT t0.id, label, entity_type, description, t2.content
-     FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
-     "get_all_edges": """SELECT t1.id, t1.keywords as label, t1.keywords, t1.source_name as source, t1.target_name as target,
-         t1.weight, t1.DESCRIPTION, t2.content
-     FROM LIGHTRAG_GRAPH_EDGES t1
-     LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
-     WHERE t1.workspace=:workspace
-     order by t1.CREATETIME DESC
-     fetch first :limit rows only""",
-     "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
-         count(distinct CASE WHEN type='edge' THEN id END) as edges_count
-     FROM (
-         select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
-             MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
-         UNION
-         select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
-             MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
-     )""",
  }
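Two related changes run through this file: SQL text is now produced by `str.format` over the templates, so `query()` drops its `params` argument, and the MERGE statements take positional bind variables (`:1` … `:9`) supplied as a list to `execute()`. For illustration, a sketch of positional binds with the python-oracledb driver (the table and values here are made up; the driver matches `:1`, `:2`, … to list positions):

```python
# Sketch of Oracle positional bind variables, the convention the merge
# statements switch to. Table and values are illustrative only.
import oracledb


def insert_row(connection: oracledb.Connection) -> None:
    sql = "INSERT INTO demo_nodes (workspace, name) VALUES (:1, :2)"
    with connection.cursor() as cursor:
        cursor.execute(sql, ["default", "entity-A"])  # bound by position
    connection.commit()
```

Formatting values directly into the SQL text, as the reworked lookup queries now do, trades bind-variable safety for simplicity, so the inputs are implicitly assumed to be trusted.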
lightrag/lightrag.py CHANGED
@@ -172,9 +172,7 @@ class LightRAG:
              embedding_func=self.embedding_func,
          )
          self.chunk_entity_relation_graph = self.graph_storage_cls(
-             namespace="chunk_entity_relation",
-             global_config=asdict(self),
-             embedding_func=self.embedding_func,
+             namespace="chunk_entity_relation", global_config=asdict(self)
          )
          ####
          # add embedding func by walter over
@@ -226,6 +224,7 @@ class LightRAG:
          return loop.run_until_complete(self.ainsert(string_or_strings))

      async def ainsert(self, string_or_strings):
+         update_storage = False
          try:
              if isinstance(string_or_strings, str):
                  string_or_strings = [string_or_strings]
@@ -239,6 +238,7 @@ class LightRAG:
              if not len(new_docs):
                  logger.warning("All docs are already in the storage")
                  return
+             update_storage = True
              logger.info(f"[New Docs] inserting {len(new_docs)} docs")

              inserting_chunks = {}
@@ -285,7 +285,8 @@ class LightRAG:
              await self.full_docs.upsert(new_docs)
              await self.text_chunks.upsert(inserting_chunks)
          finally:
-             await self._insert_done()
+             if update_storage:
+                 await self._insert_done()

      async def _insert_done(self):
          tasks = []
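`update_storage` exists because `ainsert` returns early when every submitted document is already stored; without the flag, the `finally` block would still invoke `_insert_done()` and its callbacks for a no-op insert. The guard pattern in isolation, with a hypothetical `finalize_storage()` standing in for `_insert_done()`:

```python
# Sketch of the early-return guard introduced in ainsert: the commit
# step in `finally` only runs if real work happened first.
async def insert_documents(docs: list[str]) -> None:
    update_storage = False
    try:
        new_docs = [d for d in docs if d]  # stand-in for dedup against storage
        if not new_docs:
            return  # early exit skips the commit step below
        update_storage = True
        ...  # chunking, extraction, and upserts would happen here
    finally:
        if update_storage:
            await finalize_storage()  # hypothetical stand-in for _insert_done()


async def finalize_storage() -> None:
    print("storage callbacks flushed")
```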
lightrag/llm.py CHANGED
@@ -696,13 +696,17 @@ async def bedrock_embedding(


  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+     device = next(embed_model.parameters()).device
      input_ids = tokenizer(
          texts, return_tensors="pt", padding=True, truncation=True
-     ).input_ids
+     ).input_ids.to(device)
      with torch.no_grad():
          outputs = embed_model(input_ids)
          embeddings = outputs.last_hidden_state.mean(dim=1)
-     return embeddings.detach().numpy()
+     if embeddings.dtype == torch.bfloat16:
+         return embeddings.detach().to(torch.float32).cpu().numpy()
+     else:
+         return embeddings.detach().cpu().numpy()


  async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
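Two fixes land in `hf_embedding`: the tokenized inputs are moved onto the model's own device, so a GPU-loaded model no longer receives CPU tensors, and bfloat16 outputs are upcast before conversion, because NumPy has no bfloat16 dtype and calling `.numpy()` on a bfloat16 tensor raises a `TypeError`. The conversion step in isolation:

```python
# Why the bfloat16 branch is needed: NumPy has no bfloat16 dtype, so
# the tensor must be upcast to float32 before .numpy().
import torch

t = torch.randn(2, 4, dtype=torch.bfloat16)

# t.numpy() would raise TypeError here; upcasting first works:
arr = t.detach().to(torch.float32).cpu().numpy()
print(arr.dtype)  # float32
```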
lightrag/operate.py CHANGED
@@ -662,24 +662,20 @@ async def _find_most_related_text_unit_from_entities(
      all_text_units_lookup = {}
      for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
          for c_id in this_text_units:
-             if c_id in all_text_units_lookup:
-                 continue
-             relation_counts = 0
-             if this_edges:  # Add check for None edges
+             if c_id not in all_text_units_lookup:
+                 all_text_units_lookup[c_id] = {
+                     "data": await text_chunks_db.get_by_id(c_id),
+                     "order": index,
+                     "relation_counts": 0,
+                 }
+
+             if this_edges:
                  for e in this_edges:
                      if (
                          e[1] in all_one_hop_text_units_lookup
                          and c_id in all_one_hop_text_units_lookup[e[1]]
                      ):
-                         relation_counts += 1
-
-             chunk_data = await text_chunks_db.get_by_id(c_id)
-             if chunk_data is not None and "content" in chunk_data:  # Add content check
-                 all_text_units_lookup[c_id] = {
-                     "data": chunk_data,
-                     "order": index,
-                     "relation_counts": relation_counts,
-                 }
+                         all_text_units_lookup[c_id]["relation_counts"] += 1

      # Filter out None values and ensure data has content
      all_text_units = [
@@ -714,10 +710,16 @@ async def _find_most_related_edges_from_entities(
      all_related_edges = await asyncio.gather(
          *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
      )
-     all_edges = set()
+     all_edges = []
+     seen = set()
+
      for this_edges in all_related_edges:
-         all_edges.update([tuple(sorted(e)) for e in this_edges])
-     all_edges = list(all_edges)
+         for e in this_edges:
+             sorted_edge = tuple(sorted(e))
+             if sorted_edge not in seen:
+                 seen.add(sorted_edge)
+                 all_edges.append(sorted_edge)
+
      all_edges_pack = await asyncio.gather(
          *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
      )
@@ -828,10 +830,16 @@ async def _find_most_related_entities_from_relationships(
      query_param: QueryParam,
      knowledge_graph_inst: BaseGraphStorage,
  ):
-     entity_names = set()
+     entity_names = []
+     seen = set()
+
      for e in edge_datas:
-         entity_names.add(e["src_id"])
-         entity_names.add(e["tgt_id"])
+         if e["src_id"] not in seen:
+             entity_names.append(e["src_id"])
+             seen.add(e["src_id"])
+         if e["tgt_id"] not in seen:
+             entity_names.append(e["tgt_id"])
+             seen.add(e["tgt_id"])

      node_datas = await asyncio.gather(
          *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
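All three hunks swap a bare `set` for a list guarded by a `seen` set. The difference is ordering: iterating a `set` yields an arbitrary, hash-dependent order, so the downstream ranking and truncation of edges and entities could vary between runs, whereas the list-plus-set idiom deduplicates while preserving first-seen order. The same fix is applied to `process_combine_contexts` in `lightrag/utils.py` below. The idiom in isolation:

```python
# Order-preserving deduplication: the list keeps first-seen order,
# the set provides O(1) membership checks.
def dedup_preserve_order(items: list[str]) -> list[str]:
    seen: set[str] = set()
    out: list[str] = []
    for item in items:
        if item not in seen:
            seen.add(item)
            out.append(item)
    return out


print(dedup_preserve_order(["b", "a", "b", "c", "a"]))  # ['b', 'a', 'c']
```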
lightrag/utils.py CHANGED
@@ -290,13 +290,19 @@ def process_combine_contexts(hl, ll):
      if list_ll:
          list_ll = [",".join(item[1:]) for item in list_ll if item]

-     combined_sources_set = set(filter(None, list_hl + list_ll))
+     combined_sources = []
+     seen = set()
+
+     for item in list_hl + list_ll:
+         if item and item not in seen:
+             combined_sources.append(item)
+             seen.add(item)

-     combined_sources = [",\t".join(header)]
+     combined_sources_result = [",\t".join(header)]

-     for i, item in enumerate(combined_sources_set, start=1):
-         combined_sources.append(f"{i},\t{item}")
+     for i, item in enumerate(combined_sources, start=1):
+         combined_sources_result.append(f"{i},\t{item}")

-     combined_sources = "\n".join(combined_sources)
-
-     return combined_sources
+     combined_sources_result = "\n".join(combined_sources_result)
+
+     return combined_sources_result
reproduce/Step_1.py CHANGED
@@ -24,7 +24,7 @@ def insert_text(rag, file_path):


  cls = "agriculture"
- WORKING_DIR = "../{cls}"
+ WORKING_DIR = f"../{cls}"

  if not os.path.exists(WORKING_DIR):
      os.mkdir(WORKING_DIR)
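Without the `f` prefix, `WORKING_DIR` was the literal string `"../{cls}"`, so the script created a directory literally named `{cls}` instead of `agriculture`:

```python
cls = "agriculture"
print("../{cls}")   # ../{cls}  (plain string, braces are not substituted)
print(f"../{cls}")  # ../agriculture
```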