ArnoChen commited on
Commit
dc3571b
·
1 Parent(s): 316c28a

use KnowledgeGraph typed dict for graph API response

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -1424,8 +1424,8 @@ def create_app(args):
1424
 
1425
  # query all graph
1426
  @app.get("/graphs")
1427
- async def get_graphs(label: str):
1428
- return await rag.get_graphs(nodel_label=label, max_depth=100)
1429
 
1430
  # Add Ollama API routes
1431
  ollama_api = OllamaAPI(rag)
 
1424
 
1425
  # query all graph
1426
  @app.get("/graphs")
1427
+ async def get_knowledge_graph(label: str):
1428
+ return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
1429
 
1430
  # Add Ollama API routes
1431
  ollama_api = OllamaAPI(rag)
lightrag/base.py CHANGED
@@ -13,6 +13,7 @@ from typing import (
13
  import numpy as np
14
 
15
  from .utils import EmbeddingFunc
 
16
 
17
 
18
  class TextChunkSchema(TypedDict):
@@ -175,7 +176,7 @@ class BaseGraphStorage(StorageNameSpace):
175
 
176
  async def get_knowledge_graph(
177
  self, node_label: str, max_depth: int = 5
178
- ) -> dict[str, list[dict]]:
179
  raise NotImplementedError
180
 
181
 
 
13
  import numpy as np
14
 
15
  from .utils import EmbeddingFunc
16
+ from .types import KnowledgeGraph
17
 
18
 
19
  class TextChunkSchema(TypedDict):
 
176
 
177
  async def get_knowledge_graph(
178
  self, node_label: str, max_depth: int = 5
179
+ ) -> KnowledgeGraph:
180
  raise NotImplementedError
181
 
182
 
lightrag/kg/neo4j_impl.py CHANGED
@@ -25,6 +25,7 @@ from tenacity import (
25
 
26
  from ..utils import logger
27
  from ..base import BaseGraphStorage
 
28
 
29
 
30
  @dataclass
@@ -44,7 +45,8 @@ class Neo4JStorage(BaseGraphStorage):
44
  URI = os.environ["NEO4J_URI"]
45
  USERNAME = os.environ["NEO4J_USERNAME"]
46
  PASSWORD = os.environ["NEO4J_PASSWORD"]
47
- MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
 
48
  DATABASE = os.environ.get(
49
  "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
50
  )
@@ -74,19 +76,22 @@ class Neo4JStorage(BaseGraphStorage):
74
  )
75
  raise e
76
  except neo4jExceptions.AuthError as e:
77
- logger.error(f"Authentication failed for {database} at {URI}")
 
78
  raise e
79
  except neo4jExceptions.ClientError as e:
80
  if e.code == "Neo.ClientError.Database.DatabaseNotFound":
81
  logger.info(
82
- f"{database} at {URI} not found. Try to create specified database.".capitalize()
 
83
  )
84
  try:
85
  with _sync_driver.session() as session:
86
  session.run(
87
  f"CREATE DATABASE `{database}` IF NOT EXISTS"
88
  )
89
- logger.info(f"{database} at {URI} created".capitalize())
 
90
  connected = True
91
  except (
92
  neo4jExceptions.ClientError,
@@ -103,7 +108,8 @@ class Neo4JStorage(BaseGraphStorage):
103
  "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
104
  )
105
  if database is None:
106
- logger.error(f"Failed to create {database} at {URI}")
 
107
  raise e
108
 
109
  if connected:
@@ -365,7 +371,7 @@ class Neo4JStorage(BaseGraphStorage):
365
 
366
  async def get_knowledge_graph(
367
  self, node_label: str, max_depth: int = 5
368
- ) -> Dict[str, List[Dict]]:
369
  """
370
  Get complete connected subgraph for specified node (including the starting node itself)
371
 
@@ -376,7 +382,7 @@ class Neo4JStorage(BaseGraphStorage):
376
  4. Add depth control
377
  """
378
  label = node_label.strip('"')
379
- result = {"nodes": [], "edges": []}
380
  seen_nodes = set()
381
  seen_edges = set()
382
 
@@ -395,7 +401,8 @@ class Neo4JStorage(BaseGraphStorage):
395
  validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
396
  validate_result = await session.run(validate_query)
397
  if not await validate_result.single():
398
- logger.warning(f"Starting node {label} does not exist!")
 
399
  return result
400
 
401
  # Optimized query (including direction handling and self-loops)
@@ -420,11 +427,11 @@ class Neo4JStorage(BaseGraphStorage):
420
  # Use node ID + label combination as unique identifier
421
  node_id = node.id
422
  if node_id not in seen_nodes:
423
- node_data = {}
424
- node_data["labels"] = list(node.labels) # Keep all labels
425
- node_data["id"] = f"{node_id}"
426
- node_data["properties"] = dict(node)
427
- result["nodes"].append(node_data)
428
  seen_nodes.add(node_id)
429
 
430
  # Handle relationships (including direction information)
@@ -433,21 +440,17 @@ class Neo4JStorage(BaseGraphStorage):
433
  if edge_id not in seen_edges:
434
  start = rel.start_node
435
  end = rel.end_node
436
- edge_data = {}
437
- edge_data.update(
438
- {
439
- "source": f"{start.id}",
440
- "target": f"{end.id}",
441
- "type": rel.type,
442
- "id": f"{edge_id}",
443
- "properties": dict(rel),
444
- }
445
- )
446
- result["edges"].append(edge_data)
447
  seen_edges.add(edge_id)
448
 
449
  logger.info(
450
- f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
451
  )
452
 
453
  except neo4jExceptions.ClientError as e:
 
25
 
26
  from ..utils import logger
27
  from ..base import BaseGraphStorage
28
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
29
 
30
 
31
  @dataclass
 
45
  URI = os.environ["NEO4J_URI"]
46
  USERNAME = os.environ["NEO4J_USERNAME"]
47
  PASSWORD = os.environ["NEO4J_PASSWORD"]
48
+ MAX_CONNECTION_POOL_SIZE = os.environ.get(
49
+ "NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
50
  DATABASE = os.environ.get(
51
  "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
52
  )
 
76
  )
77
  raise e
78
  except neo4jExceptions.AuthError as e:
79
+ logger.error(
80
+ f"Authentication failed for {database} at {URI}")
81
  raise e
82
  except neo4jExceptions.ClientError as e:
83
  if e.code == "Neo.ClientError.Database.DatabaseNotFound":
84
  logger.info(
85
+ f"{database} at {URI} not found. Try to create specified database.".capitalize(
86
+ )
87
  )
88
  try:
89
  with _sync_driver.session() as session:
90
  session.run(
91
  f"CREATE DATABASE `{database}` IF NOT EXISTS"
92
  )
93
+ logger.info(
94
+ f"{database} at {URI} created".capitalize())
95
  connected = True
96
  except (
97
  neo4jExceptions.ClientError,
 
108
  "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
109
  )
110
  if database is None:
111
+ logger.error(
112
+ f"Failed to create {database} at {URI}")
113
  raise e
114
 
115
  if connected:
 
371
 
372
  async def get_knowledge_graph(
373
  self, node_label: str, max_depth: int = 5
374
+ ) -> KnowledgeGraph:
375
  """
376
  Get complete connected subgraph for specified node (including the starting node itself)
377
 
 
382
  4. Add depth control
383
  """
384
  label = node_label.strip('"')
385
+ result = KnowledgeGraph()
386
  seen_nodes = set()
387
  seen_edges = set()
388
 
 
401
  validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
402
  validate_result = await session.run(validate_query)
403
  if not await validate_result.single():
404
+ logger.warning(
405
+ f"Starting node {label} does not exist!")
406
  return result
407
 
408
  # Optimized query (including direction handling and self-loops)
 
427
  # Use node ID + label combination as unique identifier
428
  node_id = node.id
429
  if node_id not in seen_nodes:
430
+ result.nodes.append(KnowledgeGraphNode(
431
+ id=f"{node_id}",
432
+ labels=list(node.labels),
433
+ properties=dict(node),
434
+ ))
435
  seen_nodes.add(node_id)
436
 
437
  # Handle relationships (including direction information)
 
440
  if edge_id not in seen_edges:
441
  start = rel.start_node
442
  end = rel.end_node
443
+ result.edges.append(KnowledgeGraphEdge(
444
+ id=f"{edge_id}",
445
+ type=rel.type,
446
+ source=f"{start.id}",
447
+ target=f"{end.id}",
448
+ properties=dict(rel),
449
+ ))
 
 
 
 
450
  seen_edges.add(edge_id)
451
 
452
  logger.info(
453
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
454
  )
455
 
456
  except neo4jExceptions.ClientError as e:
lightrag/lightrag.py CHANGED
@@ -34,6 +34,7 @@ from .utils import (
34
  logger,
35
  set_logger,
36
  )
 
37
 
38
  STORAGES = {
39
  "NetworkXStorage": ".kg.networkx_impl",
@@ -385,7 +386,7 @@ class LightRAG:
385
  text = await self.chunk_entity_relation_graph.get_all_labels()
386
  return text
387
 
388
- async def get_graphs(self, nodel_label: str, max_depth: int):
389
  return await self.chunk_entity_relation_graph.get_knowledge_graph(
390
  node_label=nodel_label, max_depth=max_depth
391
  )
 
34
  logger,
35
  set_logger,
36
  )
37
+ from .types import KnowledgeGraph
38
 
39
  STORAGES = {
40
  "NetworkXStorage": ".kg.networkx_impl",
 
386
  text = await self.chunk_entity_relation_graph.get_all_labels()
387
  return text
388
 
389
+ async def get_knowledge_graph(self, nodel_label: str, max_depth: int) -> KnowledgeGraph:
390
  return await self.chunk_entity_relation_graph.get_knowledge_graph(
391
  node_label=nodel_label, max_depth=max_depth
392
  )
lightrag/types.py CHANGED
@@ -1,7 +1,26 @@
1
  from pydantic import BaseModel
2
- from typing import List
3
 
4
 
5
  class GPTKeywordExtractionFormat(BaseModel):
6
  high_level_keywords: List[str]
7
  low_level_keywords: List[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel
2
+ from typing import List, Dict, Any
3
 
4
 
5
  class GPTKeywordExtractionFormat(BaseModel):
6
  high_level_keywords: List[str]
7
  low_level_keywords: List[str]
8
+
9
+
10
+ class KnowledgeGraphNode(BaseModel):
11
+ id: str
12
+ labels: List[str]
13
+ properties: Dict[str, Any] # anything else goes here
14
+
15
+
16
+ class KnowledgeGraphEdge(BaseModel):
17
+ id: str
18
+ type: str
19
+ source: str # id of source node
20
+ target: str # id of target node
21
+ properties: Dict[str, Any] # anything else goes here
22
+
23
+
24
+ class KnowledgeGraph(BaseModel):
25
+ nodes: List[KnowledgeGraphNode] = []
26
+ edges: List[KnowledgeGraphEdge] = []