ArnoChen committed on
Commit
1a28357
·
1 Parent(s): 056dbb4

use namespace as neo4j database name

Browse files
lightrag/kg/neo4j_impl.py CHANGED
@@ -1,6 +1,7 @@
1
  import asyncio
2
  import inspect
3
  import os
 
4
  from dataclasses import dataclass
5
  from typing import Any, Union, Tuple, List, Dict
6
  import pipmaster as pm
@@ -22,7 +23,7 @@ from tenacity import (
22
  retry_if_exception_type,
23
  )
24
 
25
- from lightrag.utils import logger
26
  from ..base import BaseGraphStorage
27
 
28
 
@@ -45,50 +46,68 @@ class Neo4JStorage(BaseGraphStorage):
45
  PASSWORD = os.environ["NEO4J_PASSWORD"]
46
  MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
47
  DATABASE = os.environ.get(
48
- "NEO4J_DATABASE"
49
- ) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
50
- self._DATABASE = DATABASE
51
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
52
  URI, auth=(USERNAME, PASSWORD)
53
  )
54
- _database_name = "home database" if DATABASE is None else f"database {DATABASE}"
 
55
  with GraphDatabase.driver(
56
  URI,
57
  auth=(USERNAME, PASSWORD),
58
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
59
  ) as _sync_driver:
60
- try:
61
- with _sync_driver.session(database=DATABASE) as session:
62
- try:
63
- session.run("MATCH (n) RETURN n LIMIT 0")
64
- logger.info(f"Connected to {DATABASE} at {URI}")
65
- except neo4jExceptions.ServiceUnavailable as e:
66
- logger.error(
67
- f"{DATABASE} at {URI} is not available".capitalize()
68
- )
69
- raise e
70
- except neo4jExceptions.AuthError as e:
71
- logger.error(f"Authentication failed for {DATABASE} at {URI}")
72
- raise e
73
- except neo4jExceptions.ClientError as e:
74
- if e.code == "Neo.ClientError.Database.DatabaseNotFound":
75
- logger.info(
76
- f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
77
- )
78
  try:
79
- with _sync_driver.session() as session:
80
- session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
81
- logger.info(f"{DATABASE} at {URI} created".capitalize())
 
 
 
 
 
 
 
 
 
 
82
  except neo4jExceptions.ClientError as e:
83
- if (
84
- e.code
85
- == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
86
- ):
87
- logger.warning(
88
- "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
89
  )
90
- logger.error(f"Failed to create {DATABASE} at {URI}")
91
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def __post_init__(self):
94
  self._node_embed_algorithms = {
@@ -117,7 +136,7 @@ class Neo4JStorage(BaseGraphStorage):
117
  result = await session.run(query)
118
  single_result = await result.single()
119
  logger.debug(
120
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
121
  )
122
  return single_result["node_exists"]
123
 
@@ -133,7 +152,7 @@ class Neo4JStorage(BaseGraphStorage):
133
  result = await session.run(query)
134
  single_result = await result.single()
135
  logger.debug(
136
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
137
  )
138
  return single_result["edgeExists"]
139
 
 
1
  import asyncio
2
  import inspect
3
  import os
4
+ import re
5
  from dataclasses import dataclass
6
  from typing import Any, Union, Tuple, List, Dict
7
  import pipmaster as pm
 
23
  retry_if_exception_type,
24
  )
25
 
26
+ from ..utils import logger
27
  from ..base import BaseGraphStorage
28
 
29
 
 
46
  PASSWORD = os.environ["NEO4J_PASSWORD"]
47
  MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
48
  DATABASE = os.environ.get(
49
+ "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
50
+ )
 
51
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
52
  URI, auth=(USERNAME, PASSWORD)
53
  )
54
+
55
+ # Try to connect to the database
56
  with GraphDatabase.driver(
57
  URI,
58
  auth=(USERNAME, PASSWORD),
59
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
60
  ) as _sync_driver:
61
+ for database in (DATABASE, None):
62
+ self._DATABASE = database
63
+ connected = False
64
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  try:
66
+ with _sync_driver.session(database=database) as session:
67
+ try:
68
+ session.run("MATCH (n) RETURN n LIMIT 0")
69
+ logger.info(f"Connected to {database} at {URI}")
70
+ connected = True
71
+ except neo4jExceptions.ServiceUnavailable as e:
72
+ logger.error(
73
+ f"{database} at {URI} is not available".capitalize()
74
+ )
75
+ raise e
76
+ except neo4jExceptions.AuthError as e:
77
+ logger.error(f"Authentication failed for {database} at {URI}")
78
+ raise e
79
  except neo4jExceptions.ClientError as e:
80
+ if e.code == "Neo.ClientError.Database.DatabaseNotFound":
81
+ logger.info(
82
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
 
 
 
83
  )
84
+ try:
85
+ with _sync_driver.session() as session:
86
+ session.run(
87
+ f"CREATE DATABASE `{database}` IF NOT EXISTS"
88
+ )
89
+ logger.info(f"{database} at {URI} created".capitalize())
90
+ connected = True
91
+ except (
92
+ neo4jExceptions.ClientError,
93
+ neo4jExceptions.DatabaseError,
94
+ ) as e:
95
+ if (
96
+ e.code
97
+ == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
98
+ ) or (
99
+ e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
100
+ ):
101
+ if database is not None:
102
+ logger.warning(
103
+ "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
104
+ )
105
+ if database is None:
106
+ logger.error(f"Failed to create {database} at {URI}")
107
+ raise e
108
+
109
+ if connected:
110
+ break
111
 
112
  def __post_init__(self):
113
  self._node_embed_algorithms = {
 
136
  result = await session.run(query)
137
  single_result = await result.single()
138
  logger.debug(
139
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
140
  )
141
  return single_result["node_exists"]
142
 
 
152
  result = await session.run(query)
153
  single_result = await result.single()
154
  logger.debug(
155
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
156
  )
157
  return single_result["edgeExists"]
158
 
lightrag/kg/oracle_impl.py CHANGED
@@ -257,7 +257,8 @@ class OracleKVStorage(BaseKVStorage):
257
  async def filter_keys(self, keys: list[str]) -> set[str]:
258
  """Return keys that don't exist in storage"""
259
  SQL = SQL_TEMPLATES["filter_keys"].format(
260
- table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
 
261
  )
262
  params = {"workspace": self.db.workspace}
263
  res = await self.db.query(SQL, params, multirows=True)
@@ -330,7 +331,9 @@ class OracleKVStorage(BaseKVStorage):
330
  return None
331
 
332
  async def change_status(self, id: str, status: str):
333
- SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
 
 
334
  params = {"workspace": self.db.workspace, "id": id, "status": status}
335
  await self.db.execute(SQL, params)
336
 
@@ -623,6 +626,7 @@ N_T = {
623
  NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
624
  }
625
 
 
626
  def namespace_to_table_name(namespace: str) -> str:
627
  for k, v in N_T.items():
628
  if is_namespace(namespace, k):
 
257
  async def filter_keys(self, keys: list[str]) -> set[str]:
258
  """Return keys that don't exist in storage"""
259
  SQL = SQL_TEMPLATES["filter_keys"].format(
260
+ table_name=namespace_to_table_name(self.namespace),
261
+ ids=",".join([f"'{id}'" for id in keys]),
262
  )
263
  params = {"workspace": self.db.workspace}
264
  res = await self.db.query(SQL, params, multirows=True)
 
331
  return None
332
 
333
  async def change_status(self, id: str, status: str):
334
+ SQL = SQL_TEMPLATES["change_status"].format(
335
+ table_name=namespace_to_table_name(self.namespace)
336
+ )
337
  params = {"workspace": self.db.workspace, "id": id, "status": status}
338
  await self.db.execute(SQL, params)
339
 
 
626
  NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
627
  }
628
 
629
+
630
  def namespace_to_table_name(namespace: str) -> str:
631
  for k, v in N_T.items():
632
  if is_namespace(namespace, k):
lightrag/lightrag.py CHANGED
@@ -236,7 +236,9 @@ class LightRAG:
236
  )
237
 
238
  self.llm_response_cache = self.key_string_value_json_storage_cls(
239
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
240
  embedding_func=self.embedding_func,
241
  )
242
 
@@ -244,15 +246,21 @@ class LightRAG:
244
  # add embedding func by walter
245
  ####
246
  self.full_docs = self.key_string_value_json_storage_cls(
247
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS),
 
 
248
  embedding_func=self.embedding_func,
249
  )
250
  self.text_chunks = self.key_string_value_json_storage_cls(
251
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS),
 
 
252
  embedding_func=self.embedding_func,
253
  )
254
  self.chunk_entity_relation_graph = self.graph_storage_cls(
255
- namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION),
 
 
256
  embedding_func=self.embedding_func,
257
  )
258
  ####
@@ -260,17 +268,23 @@ class LightRAG:
260
  ####
261
 
262
  self.entities_vdb = self.vector_db_storage_cls(
263
- namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES),
 
 
264
  embedding_func=self.embedding_func,
265
  meta_fields={"entity_name"},
266
  )
267
  self.relationships_vdb = self.vector_db_storage_cls(
268
- namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS),
 
 
269
  embedding_func=self.embedding_func,
270
  meta_fields={"src_id", "tgt_id"},
271
  )
272
  self.chunks_vdb = self.vector_db_storage_cls(
273
- namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS),
 
 
274
  embedding_func=self.embedding_func,
275
  )
276
 
@@ -280,7 +294,9 @@ class LightRAG:
280
  hashing_kv = self.llm_response_cache
281
  else:
282
  hashing_kv = self.key_string_value_json_storage_cls(
283
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
284
  embedding_func=self.embedding_func,
285
  )
286
 
@@ -931,7 +947,9 @@ class LightRAG:
931
  if self.llm_response_cache
932
  and hasattr(self.llm_response_cache, "global_config")
933
  else self.key_string_value_json_storage_cls(
934
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
935
  global_config=asdict(self),
936
  embedding_func=self.embedding_func,
937
  ),
@@ -948,7 +966,9 @@ class LightRAG:
948
  if self.llm_response_cache
949
  and hasattr(self.llm_response_cache, "global_config")
950
  else self.key_string_value_json_storage_cls(
951
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
952
  global_config=asdict(self),
953
  embedding_func=self.embedding_func,
954
  ),
@@ -967,7 +987,9 @@ class LightRAG:
967
  if self.llm_response_cache
968
  and hasattr(self.llm_response_cache, "global_config")
969
  else self.key_string_value_json_storage_cls(
970
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
971
  global_config=asdict(self),
972
  embedding_func=self.embedding_func,
973
  ),
@@ -1008,7 +1030,9 @@ class LightRAG:
1008
  global_config=asdict(self),
1009
  hashing_kv=self.llm_response_cache
1010
  or self.key_string_value_json_storage_cls(
1011
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
1012
  global_config=asdict(self),
1013
  embedding_func=self.embedding_func,
1014
  ),
@@ -1039,7 +1063,9 @@ class LightRAG:
1039
  if self.llm_response_cache
1040
  and hasattr(self.llm_response_cache, "global_config")
1041
  else self.key_string_value_json_storage_cls(
1042
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
1043
  global_config=asdict(self),
1044
  embedding_func=self.embedding_funcne,
1045
  ),
@@ -1055,7 +1081,9 @@ class LightRAG:
1055
  if self.llm_response_cache
1056
  and hasattr(self.llm_response_cache, "global_config")
1057
  else self.key_string_value_json_storage_cls(
1058
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
1059
  global_config=asdict(self),
1060
  embedding_func=self.embedding_func,
1061
  ),
@@ -1074,7 +1102,9 @@ class LightRAG:
1074
  if self.llm_response_cache
1075
  and hasattr(self.llm_response_cache, "global_config")
1076
  else self.key_string_value_json_storage_cls(
1077
- namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
 
 
1078
  global_config=asdict(self),
1079
  embedding_func=self.embedding_func,
1080
  ),
 
236
  )
237
 
238
  self.llm_response_cache = self.key_string_value_json_storage_cls(
239
+ namespace=make_namespace(
240
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
241
+ ),
242
  embedding_func=self.embedding_func,
243
  )
244
 
 
246
  # add embedding func by walter
247
  ####
248
  self.full_docs = self.key_string_value_json_storage_cls(
249
+ namespace=make_namespace(
250
+ self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
251
+ ),
252
  embedding_func=self.embedding_func,
253
  )
254
  self.text_chunks = self.key_string_value_json_storage_cls(
255
+ namespace=make_namespace(
256
+ self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
257
+ ),
258
  embedding_func=self.embedding_func,
259
  )
260
  self.chunk_entity_relation_graph = self.graph_storage_cls(
261
+ namespace=make_namespace(
262
+ self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
263
+ ),
264
  embedding_func=self.embedding_func,
265
  )
266
  ####
 
268
  ####
269
 
270
  self.entities_vdb = self.vector_db_storage_cls(
271
+ namespace=make_namespace(
272
+ self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
273
+ ),
274
  embedding_func=self.embedding_func,
275
  meta_fields={"entity_name"},
276
  )
277
  self.relationships_vdb = self.vector_db_storage_cls(
278
+ namespace=make_namespace(
279
+ self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
280
+ ),
281
  embedding_func=self.embedding_func,
282
  meta_fields={"src_id", "tgt_id"},
283
  )
284
  self.chunks_vdb = self.vector_db_storage_cls(
285
+ namespace=make_namespace(
286
+ self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
287
+ ),
288
  embedding_func=self.embedding_func,
289
  )
290
 
 
294
  hashing_kv = self.llm_response_cache
295
  else:
296
  hashing_kv = self.key_string_value_json_storage_cls(
297
+ namespace=make_namespace(
298
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
299
+ ),
300
  embedding_func=self.embedding_func,
301
  )
302
 
 
947
  if self.llm_response_cache
948
  and hasattr(self.llm_response_cache, "global_config")
949
  else self.key_string_value_json_storage_cls(
950
+ namespace=make_namespace(
951
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
952
+ ),
953
  global_config=asdict(self),
954
  embedding_func=self.embedding_func,
955
  ),
 
966
  if self.llm_response_cache
967
  and hasattr(self.llm_response_cache, "global_config")
968
  else self.key_string_value_json_storage_cls(
969
+ namespace=make_namespace(
970
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
971
+ ),
972
  global_config=asdict(self),
973
  embedding_func=self.embedding_func,
974
  ),
 
987
  if self.llm_response_cache
988
  and hasattr(self.llm_response_cache, "global_config")
989
  else self.key_string_value_json_storage_cls(
990
+ namespace=make_namespace(
991
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
992
+ ),
993
  global_config=asdict(self),
994
  embedding_func=self.embedding_func,
995
  ),
 
1030
  global_config=asdict(self),
1031
  hashing_kv=self.llm_response_cache
1032
  or self.key_string_value_json_storage_cls(
1033
+ namespace=make_namespace(
1034
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1035
+ ),
1036
  global_config=asdict(self),
1037
  embedding_func=self.embedding_func,
1038
  ),
 
1063
  if self.llm_response_cache
1064
  and hasattr(self.llm_response_cache, "global_config")
1065
  else self.key_string_value_json_storage_cls(
1066
+ namespace=make_namespace(
1067
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1068
+ ),
1069
  global_config=asdict(self),
1070
  embedding_func=self.embedding_funcne,
1071
  ),
 
1081
  if self.llm_response_cache
1082
  and hasattr(self.llm_response_cache, "global_config")
1083
  else self.key_string_value_json_storage_cls(
1084
+ namespace=make_namespace(
1085
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1086
+ ),
1087
  global_config=asdict(self),
1088
  embedding_func=self.embedding_func,
1089
  ),
 
1102
  if self.llm_response_cache
1103
  and hasattr(self.llm_response_cache, "global_config")
1104
  else self.key_string_value_json_storage_cls(
1105
+ namespace=make_namespace(
1106
+ self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1107
+ ),
1108
  global_config=asdict(self),
1109
  embedding_func=self.embedding_func,
1110
  ),