yangdx committed c71b14f (parent: 9b765c7)

Fix linting

Files changed:
- lightrag/kg/neo4j_impl.py (+30, -16)
- lightrag/kg/shared_storage.py (+84, -18)
- lightrag/operate.py (+6, -3)
lightrag/kg/neo4j_impl.py CHANGED
@@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage):

        Args:
            label: The label to validate
-
+
        Returns:
            str: The cleaned label
-
+
        Raises:
            ValueError: If label is empty after cleaning
        """
@@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage):
            query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
            result = await session.run(query, entity_id=entity_name_label)
            try:
-                records = await result.fetch(2)  # Get 2 records for duplication check
+                records = await result.fetch(
+                    2
+                )  # Get 2 records for duplication check

                if len(records) > 1:
                    logger.warning(
@@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage):

        try:
            async with self._driver.session(database=self._DATABASE) as session:
+
                async def execute_upsert(tx: AsyncManagedTransaction):
                    query = f"""
                    MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
@@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage):
                        f"Upserted node with label '{label}' and properties: {properties}"
                    )
                    await result.consume()  # Ensure result is fully consumed
-
+
                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
@@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage):
            """
            result = await session.run(query)
            try:
-                records = await result.fetch(2)  # We only need to know if there are 0, 1, or >1 nodes
-
+                records = await result.fetch(
+                    2
+                )  # We only need to know if there are 0, 1, or >1 nodes
+
                if not records or records[0]["node_count"] == 0:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not exist")
-
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not exist"
+                    )
+
                if records[0]["node_count"] > 1:
-                    raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node")
-
+                    raise ValueError(
+                        f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
+                    )
+
                node = records[0]["n"]
                if "entity_id" not in node:
-                    raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property")
-
+                    raise ValueError(
+                        f"Neo4j: node with label '{node_label}' does not have an entity_id property"
+                    )
+
                return node["entity_id"]
            finally:
                await result.consume()  # Ensure result is fully consumed
@@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage):

        try:
            async with self._driver.session(database=self._DATABASE) as session:
+
                async def execute_upsert(tx: AsyncManagedTransaction):
                    query = f"""
                    MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
@@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage):
                    RETURN r, source, target
                    """
                    result = await tx.run(
-                        query,
+                        query,
                        source_entity_id=source_entity_id,
                        target_entity_id=target_entity_id,
-                        properties=edge_properties
+                        properties=edge_properties,
                    )
                    try:
                        records = await result.fetch(100)
@@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage):
                        )
                    finally:
                        await result.consume()  # Ensure result is consumed
-
+
                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
@@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage):
                results = await session.run(query, {"node_id": node.id})

                # Get all records and release database connection
-                records = await results.fetch(1000)  # Max neighbour nodes we can handled
+                records = await results.fetch(
+                    1000
+                )  # Max neighbour nodes we can handled
                await results.consume()  # Ensure results are consumed

                # Nodes not connected to start node need to check degree
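
The pattern repeated across these hunks is: fetch a small, bounded number of records, then consume() the result so the driver can release the connection back to the pool. A minimal standalone sketch of the duplication check with the async Neo4j driver follows; the URI, credentials, and the Entity label are illustrative assumptions, not taken from this commit:

import asyncio

from neo4j import AsyncGraphDatabase


async def entity_is_duplicated(entity_id: str) -> bool:
    # Hypothetical connection details, for illustration only
    driver = AsyncGraphDatabase.driver(
        "bolt://localhost:7687", auth=("neo4j", "password")
    )
    try:
        async with driver.session() as session:
            result = await session.run(
                "MATCH (n:Entity {entity_id: $entity_id}) RETURN n",
                entity_id=entity_id,
            )
            # fetch(2) buffers at most two records: enough to tell
            # "exactly one" from "more than one" without pulling everything
            records = await result.fetch(2)
            await result.consume()  # release the connection back to the pool
            return len(records) > 1
    finally:
        await driver.close()

# Requires a live database: asyncio.run(entity_is_duplicated("entity-123"))

Leaving a result unconsumed keeps the connection pinned to the cursor, which is why every hunk above pairs fetch() with consume(), often in a finally block.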
lightrag/kg/shared_storage.py CHANGED
@@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True):
    """
    Log a message directly to stderr to ensure visibility in all processes,
    including the Gunicorn master process.
-
+
    Args:
        message: The message to log
        level: Log level (default: "INFO")
@@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None
class UnifiedLock(Generic[T]):
    """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""

-    def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True):
+    def __init__(
+        self,
+        lock: Union[ProcessLock, asyncio.Lock],
+        is_async: bool,
+        name: str = "unnamed",
+        enable_logging: bool = True,
+    ):
        self._lock = lock
        self._is_async = is_async
        self._pid = os.getpid()  # for debug only
@@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]):

    async def __aenter__(self) -> "UnifiedLock[T]":
        try:
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
            if self._is_async:
                await self._lock.acquire()
            else:
                self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
            return self
        except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
            raise

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
            if self._is_async:
                self._lock.release()
            else:
                self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
+                enable_output=self._enable_logging,
+            )
        except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
            raise

    def __enter__(self) -> "UnifiedLock[T]":
@@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]):
        try:
            if self._is_async:
                raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
            self._lock.acquire()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
+                enable_output=self._enable_logging,
+            )
            return self
        except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]):
        try:
            if self._is_async:
                raise RuntimeError("Use 'async with' for shared_storage lock")
-            direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
+                enable_output=self._enable_logging,
+            )
            self._lock.release()
-            direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
+                enable_output=self._enable_logging,
+            )
        except Exception as e:
-            direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging)
+            direct_log(
+                f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
+                level="ERROR",
+                enable_output=self._enable_logging,
+            )
            raise


def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_internal_lock,
+        is_async=not is_multiprocess,
+        name="internal_lock",
+        enable_logging=enable_logging,
+    )


def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_storage_lock,
+        is_async=not is_multiprocess,
+        name="storage_lock",
+        enable_logging=enable_logging,
+    )


def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified storage lock for data consistency"""
-    return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_pipeline_status_lock,
+        is_async=not is_multiprocess,
+        name="pipeline_status_lock",
+        enable_logging=enable_logging,
+    )


def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
    """return unified graph database lock for ensuring atomic operations"""
-    return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging)
+    return UnifiedLock(
+        lock=_graph_db_lock,
+        is_async=not is_multiprocess,
+        name="graph_db_lock",
+        enable_logging=enable_logging,
+    )


def initialize_share_data(workers: int = 1):
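
For orientation, here is a reduced sketch of the idea behind UnifiedLock: one "async with" interface over either an asyncio.Lock (single process) or a multiprocessing lock (multiple Gunicorn workers). MiniUnifiedLock is a simplified stand-in for illustration, not the class as committed; the name/enable_logging plumbing and direct_log calls from the diff are omitted:

import asyncio
import multiprocessing


class MiniUnifiedLock:
    """Reduced stand-in for UnifiedLock: one interface, two lock kinds."""

    def __init__(self, lock, is_async: bool, name: str = "unnamed"):
        self._lock = lock
        self._is_async = is_async
        self._name = name

    async def __aenter__(self):
        if self._is_async:
            await self._lock.acquire()  # asyncio.Lock.acquire() is awaitable
        else:
            self._lock.acquire()  # multiprocessing lock blocks the thread
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()  # release() is synchronous for both kinds


async def demo():
    # Async flavor: cooperative within one event loop
    async with MiniUnifiedLock(asyncio.Lock(), is_async=True, name="async-demo"):
        print("async critical section")
    # Process flavor: same interface, OS-level lock underneath
    async with MiniUnifiedLock(multiprocessing.Lock(), is_async=False, name="mp-demo"):
        print("cross-process critical section")


asyncio.run(demo())

Note that the blocking acquire() on the multiprocessing path stalls the event loop while it waits; the committed class appears to accept that trade-off in exchange for cross-process safety.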
lightrag/operate.py CHANGED
@@ -522,8 +522,9 @@ async def extract_entities(
        maybe_edges[tuple(sorted(k))].extend(v)

    from .kg.shared_storage import get_graph_db_lock
-
-    graph_db_lock = get_graph_db_lock(enable_logging=True)
+
+    graph_db_lock = get_graph_db_lock(enable_logging=True)
+
    # Ensure that nodes and edges are merged and upserted atomically
    async with graph_db_lock:
        all_entities_data = await asyncio.gather(
@@ -535,7 +536,9 @@ async def extract_entities(

        all_relationships_data = await asyncio.gather(
            *[
-                _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+                _merge_edges_then_upsert(
+                    k[0], k[1], v, knowledge_graph_inst, global_config
+                )
                for k, v in maybe_edges.items()
            ]
        )
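
The shape of the change in extract_entities, reduced to a runnable sketch: a shared lock makes the merge-and-upsert phase atomic with respect to other pipelines, while asyncio.gather still fans the per-item work out concurrently inside it. merge_item below is a hypothetical stand-in for the _merge_edges_then_upsert call shown above:

import asyncio

graph_db_lock = asyncio.Lock()


async def merge_item(key: str) -> str:
    await asyncio.sleep(0.01)  # stands in for the real graph read/merge/write
    return key


async def extract(items: list[str]) -> list[str]:
    # Hold the lock across the whole batch so no other pipeline can
    # interleave its own merges between these upserts
    async with graph_db_lock:
        return await asyncio.gather(*[merge_item(k) for k in items])


print(asyncio.run(extract(["alpha", "beta"])))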