ArnoChen commited on
Commit
e1a2519
·
1 Parent(s): 65fd96e

improve conditional checks for db instance

Browse files
lightrag/kg/mongo_impl.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from dataclasses import dataclass
3
  import numpy as np
4
  import configparser
5
  from tqdm.asyncio import tqdm as tqdm_async
@@ -27,7 +27,11 @@ if not pm.is_installed("motor"):
27
  pm.install("motor")
28
 
29
  try:
30
- from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
 
 
 
 
31
  from pymongo.operations import SearchIndexModel
32
  from pymongo.errors import PyMongoError
33
  except ImportError as e:
@@ -79,19 +83,23 @@ class ClientManager:
79
  @final
80
  @dataclass
81
  class MongoKVStorage(BaseKVStorage):
 
 
 
82
  def __post_init__(self):
83
  self._collection_name = self.namespace
84
 
85
  async def initialize(self):
86
- if not hasattr(self, "db") or self.db is None:
87
  self.db = await ClientManager.get_client()
88
  self._data = await get_or_create_collection(self.db, self._collection_name)
89
  logger.debug(f"Use MongoDB as KV {self._collection_name}")
90
 
91
  async def finalize(self):
92
- if hasattr(self, "db") and self.db is not None:
93
  await ClientManager.release_client(self.db)
94
  self.db = None
 
95
 
96
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
97
  return await self._data.find_one({"_id": id})
@@ -148,19 +156,23 @@ class MongoKVStorage(BaseKVStorage):
148
  @final
149
  @dataclass
150
  class MongoDocStatusStorage(DocStatusStorage):
 
 
 
151
  def __post_init__(self):
152
  self._collection_name = self.namespace
153
 
154
  async def initialize(self):
155
- if not hasattr(self, "db") or self.db is None:
156
  self.db = await ClientManager.get_client()
157
  self._data = await get_or_create_collection(self.db, self._collection_name)
158
  logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
159
 
160
  async def finalize(self):
161
- if hasattr(self, "db") and self.db is not None:
162
  await ClientManager.release_client(self.db)
163
  self.db = None
 
164
 
165
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
166
  return await self._data.find_one({"_id": id})
@@ -221,9 +233,12 @@ class MongoDocStatusStorage(DocStatusStorage):
221
  @dataclass
222
  class MongoGraphStorage(BaseGraphStorage):
223
  """
224
- A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
225
  """
226
 
 
 
 
227
  def __init__(self, namespace, global_config, embedding_func):
228
  super().__init__(
229
  namespace=namespace,
@@ -233,7 +248,7 @@ class MongoGraphStorage(BaseGraphStorage):
233
  self._collection_name = self.namespace
234
 
235
  async def initialize(self):
236
- if not hasattr(self, "db") or self.db is None:
237
  self.db = await ClientManager.get_client()
238
  self.collection = await get_or_create_collection(
239
  self.db, self._collection_name
@@ -241,9 +256,10 @@ class MongoGraphStorage(BaseGraphStorage):
241
  logger.debug(f"Use MongoDB as KG {self._collection_name}")
242
 
243
  async def finalize(self):
244
- if hasattr(self, "db") and self.db is not None:
245
  await ClientManager.release_client(self.db)
246
  self.db = None
 
247
 
248
  #
249
  # -------------------------------------------------------------------------
@@ -782,6 +798,9 @@ class MongoGraphStorage(BaseGraphStorage):
782
  @final
783
  @dataclass
784
  class MongoVectorDBStorage(BaseVectorStorage):
 
 
 
785
  def __post_init__(self):
786
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
787
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -794,7 +813,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
794
  self._max_batch_size = self.global_config["embedding_batch_num"]
795
 
796
  async def initialize(self):
797
- if not hasattr(self, "db") or self.db is None:
798
  self.db = await ClientManager.get_client()
799
  self._data = await get_or_create_collection(self.db, self._collection_name)
800
 
@@ -804,9 +823,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
804
  logger.debug(f"Use MongoDB as VDB {self._collection_name}")
805
 
806
  async def finalize(self):
807
- if hasattr(self, "db") and self.db is not None:
808
  await ClientManager.release_client(self.db)
809
  self.db = None
 
810
 
811
  async def create_vector_index_if_not_exists(self):
812
  """Creates an Atlas Vector Search index."""
 
1
  import os
2
+ from dataclasses import dataclass, field
3
  import numpy as np
4
  import configparser
5
  from tqdm.asyncio import tqdm as tqdm_async
 
27
  pm.install("motor")
28
 
29
  try:
30
+ from motor.motor_asyncio import (
31
+ AsyncIOMotorClient,
32
+ AsyncIOMotorDatabase,
33
+ AsyncIOMotorCollection,
34
+ )
35
  from pymongo.operations import SearchIndexModel
36
  from pymongo.errors import PyMongoError
37
  except ImportError as e:
 
83
  @final
84
  @dataclass
85
  class MongoKVStorage(BaseKVStorage):
86
+ db: AsyncIOMotorDatabase = field(init=False)
87
+ _data: AsyncIOMotorCollection = field(init=False)
88
+
89
  def __post_init__(self):
90
  self._collection_name = self.namespace
91
 
92
  async def initialize(self):
93
+ if self.db is None:
94
  self.db = await ClientManager.get_client()
95
  self._data = await get_or_create_collection(self.db, self._collection_name)
96
  logger.debug(f"Use MongoDB as KV {self._collection_name}")
97
 
98
  async def finalize(self):
99
+ if self.db is not None:
100
  await ClientManager.release_client(self.db)
101
  self.db = None
102
+ self._data = None
103
 
104
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
105
  return await self._data.find_one({"_id": id})
 
156
  @final
157
  @dataclass
158
  class MongoDocStatusStorage(DocStatusStorage):
159
+ db: AsyncIOMotorDatabase = field(init=False)
160
+ _data: AsyncIOMotorCollection = field(init=False)
161
+
162
  def __post_init__(self):
163
  self._collection_name = self.namespace
164
 
165
  async def initialize(self):
166
+ if self.db is None:
167
  self.db = await ClientManager.get_client()
168
  self._data = await get_or_create_collection(self.db, self._collection_name)
169
  logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
170
 
171
  async def finalize(self):
172
+ if self.db is not None:
173
  await ClientManager.release_client(self.db)
174
  self.db = None
175
+ self._data = None
176
 
177
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
178
  return await self._data.find_one({"_id": id})
 
233
  @dataclass
234
  class MongoGraphStorage(BaseGraphStorage):
235
  """
236
+ A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
237
  """
238
 
239
+ db: AsyncIOMotorDatabase = field(init=False)
240
+ collection: AsyncIOMotorCollection = field(init=False)
241
+
242
  def __init__(self, namespace, global_config, embedding_func):
243
  super().__init__(
244
  namespace=namespace,
 
248
  self._collection_name = self.namespace
249
 
250
  async def initialize(self):
251
+ if self.db is None:
252
  self.db = await ClientManager.get_client()
253
  self.collection = await get_or_create_collection(
254
  self.db, self._collection_name
 
256
  logger.debug(f"Use MongoDB as KG {self._collection_name}")
257
 
258
  async def finalize(self):
259
+ if self.db is not None:
260
  await ClientManager.release_client(self.db)
261
  self.db = None
262
+ self.collection = None
263
 
264
  #
265
  # -------------------------------------------------------------------------
 
798
  @final
799
  @dataclass
800
  class MongoVectorDBStorage(BaseVectorStorage):
801
+ db: AsyncIOMotorDatabase = field(init=False)
802
+ _data: AsyncIOMotorCollection = field(init=False)
803
+
804
  def __post_init__(self):
805
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
806
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
813
  self._max_batch_size = self.global_config["embedding_batch_num"]
814
 
815
  async def initialize(self):
816
+ if self.db is None:
817
  self.db = await ClientManager.get_client()
818
  self._data = await get_or_create_collection(self.db, self._collection_name)
819
 
 
823
  logger.debug(f"Use MongoDB as VDB {self._collection_name}")
824
 
825
  async def finalize(self):
826
+ if self.db is not None:
827
  await ClientManager.release_client(self.db)
828
  self.db = None
829
+ self._data = None
830
 
831
  async def create_vector_index_if_not_exists(self):
832
  """Creates an Atlas Vector Search index."""
lightrag/kg/oracle_impl.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
3
 
4
  # import html
5
  import os
6
- from dataclasses import dataclass
7
  from typing import Any, Union, final
8
  import numpy as np
9
  import configparser
@@ -242,8 +242,7 @@ class ClientManager:
242
  @final
243
  @dataclass
244
  class OracleKVStorage(BaseKVStorage):
245
- # db instance must be injected before use
246
- # db: OracleDB
247
  meta_fields = None
248
 
249
  def __post_init__(self):
@@ -251,11 +250,11 @@ class OracleKVStorage(BaseKVStorage):
251
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
252
 
253
  async def initialize(self):
254
- if not hasattr(self, "db") or self.db is None:
255
  self.db = await ClientManager.get_client()
256
 
257
  async def finalize(self):
258
- if hasattr(self, "db") and self.db is not None:
259
  await ClientManager.release_client(self.db)
260
  self.db = None
261
 
@@ -395,6 +394,8 @@ class OracleKVStorage(BaseKVStorage):
395
  @final
396
  @dataclass
397
  class OracleVectorDBStorage(BaseVectorStorage):
 
 
398
  def __post_init__(self):
399
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
400
  cosine_threshold = config.get("cosine_better_than_threshold")
@@ -405,11 +406,11 @@ class OracleVectorDBStorage(BaseVectorStorage):
405
  self.cosine_better_than_threshold = cosine_threshold
406
 
407
  async def initialize(self):
408
- if not hasattr(self, "db") or self.db is None:
409
  self.db = await ClientManager.get_client()
410
 
411
  async def finalize(self):
412
- if hasattr(self, "db") and self.db is not None:
413
  await ClientManager.release_client(self.db)
414
  self.db = None
415
 
@@ -449,15 +450,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
449
  @final
450
  @dataclass
451
  class OracleGraphStorage(BaseGraphStorage):
 
 
452
  def __post_init__(self):
453
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
454
 
455
  async def initialize(self):
456
- if not hasattr(self, "db") or self.db is None:
457
  self.db = await ClientManager.get_client()
458
 
459
  async def finalize(self):
460
- if hasattr(self, "db") and self.db is not None:
461
  await ClientManager.release_client(self.db)
462
  self.db = None
463
 
 
3
 
4
  # import html
5
  import os
6
+ from dataclasses import dataclass, field
7
  from typing import Any, Union, final
8
  import numpy as np
9
  import configparser
 
242
  @final
243
  @dataclass
244
  class OracleKVStorage(BaseKVStorage):
245
+ db: OracleDB = field(init=False)
 
246
  meta_fields = None
247
 
248
  def __post_init__(self):
 
250
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
251
 
252
  async def initialize(self):
253
+ if self.db is None:
254
  self.db = await ClientManager.get_client()
255
 
256
  async def finalize(self):
257
+ if self.db is not None:
258
  await ClientManager.release_client(self.db)
259
  self.db = None
260
 
 
394
  @final
395
  @dataclass
396
  class OracleVectorDBStorage(BaseVectorStorage):
397
+ db: OracleDB = field(init=False)
398
+
399
  def __post_init__(self):
400
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
401
  cosine_threshold = config.get("cosine_better_than_threshold")
 
406
  self.cosine_better_than_threshold = cosine_threshold
407
 
408
  async def initialize(self):
409
+ if self.db is None:
410
  self.db = await ClientManager.get_client()
411
 
412
  async def finalize(self):
413
+ if self.db is not None:
414
  await ClientManager.release_client(self.db)
415
  self.db = None
416
 
 
450
  @final
451
  @dataclass
452
  class OracleGraphStorage(BaseGraphStorage):
453
+ db: OracleDB = field(init=False)
454
+
455
  def __post_init__(self):
456
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
457
 
458
  async def initialize(self):
459
+ if self.db is None:
460
  self.db = await ClientManager.get_client()
461
 
462
  async def finalize(self):
463
+ if self.db is not None:
464
  await ClientManager.release_client(self.db)
465
  self.db = None
466
 
lightrag/kg/postgres_impl.py CHANGED
@@ -3,7 +3,7 @@ import inspect
3
  import json
4
  import os
5
  import time
6
- from dataclasses import dataclass
7
  from typing import Any, Dict, List, Union, final
8
  import numpy as np
9
  import configparser
@@ -246,18 +246,17 @@ class ClientManager:
246
  @final
247
  @dataclass
248
  class PGKVStorage(BaseKVStorage):
249
- # db instance must be injected before use
250
- # db: PostgreSQLDB
251
 
252
  def __post_init__(self):
253
  self._max_batch_size = self.global_config["embedding_batch_num"]
254
 
255
  async def initialize(self):
256
- if not hasattr(self, "db") or self.db is None:
257
  self.db = await ClientManager.get_client()
258
 
259
  async def finalize(self):
260
- if hasattr(self, "db") and self.db is not None:
261
  await ClientManager.release_client(self.db)
262
  self.db = None
263
 
@@ -379,6 +378,8 @@ class PGKVStorage(BaseKVStorage):
379
  @final
380
  @dataclass
381
  class PGVectorStorage(BaseVectorStorage):
 
 
382
  def __post_init__(self):
383
  self._max_batch_size = self.global_config["embedding_batch_num"]
384
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -390,11 +391,11 @@ class PGVectorStorage(BaseVectorStorage):
390
  self.cosine_better_than_threshold = cosine_threshold
391
 
392
  async def initialize(self):
393
- if not hasattr(self, "db") or self.db is None:
394
  self.db = await ClientManager.get_client()
395
 
396
  async def finalize(self):
397
- if hasattr(self, "db") and self.db is not None:
398
  await ClientManager.release_client(self.db)
399
  self.db = None
400
 
@@ -514,12 +515,14 @@ class PGVectorStorage(BaseVectorStorage):
514
  @final
515
  @dataclass
516
  class PGDocStatusStorage(DocStatusStorage):
 
 
517
  async def initialize(self):
518
- if not hasattr(self, "db") or self.db is None:
519
  self.db = await ClientManager.get_client()
520
 
521
  async def finalize(self):
522
- if hasattr(self, "db") and self.db is not None:
523
  await ClientManager.release_client(self.db)
524
  self.db = None
525
 
@@ -662,6 +665,8 @@ class PGGraphQueryException(Exception):
662
  @final
663
  @dataclass
664
  class PGGraphStorage(BaseGraphStorage):
 
 
665
  @staticmethod
666
  def load_nx_graph(file_name):
667
  print("no preloading of graph with AGE in production")
@@ -673,11 +678,11 @@ class PGGraphStorage(BaseGraphStorage):
673
  }
674
 
675
  async def initialize(self):
676
- if not hasattr(self, "db") or self.db is None:
677
  self.db = await ClientManager.get_client()
678
 
679
  async def finalize(self):
680
- if hasattr(self, "db") and self.db is not None:
681
  await ClientManager.release_client(self.db)
682
  self.db = None
683
 
 
3
  import json
4
  import os
5
  import time
6
+ from dataclasses import dataclass, field
7
  from typing import Any, Dict, List, Union, final
8
  import numpy as np
9
  import configparser
 
246
  @final
247
  @dataclass
248
  class PGKVStorage(BaseKVStorage):
249
+ db: PostgreSQLDB = field(init=False)
 
250
 
251
  def __post_init__(self):
252
  self._max_batch_size = self.global_config["embedding_batch_num"]
253
 
254
  async def initialize(self):
255
+ if self.db is None:
256
  self.db = await ClientManager.get_client()
257
 
258
  async def finalize(self):
259
+ if self.db is not None:
260
  await ClientManager.release_client(self.db)
261
  self.db = None
262
 
 
378
  @final
379
  @dataclass
380
  class PGVectorStorage(BaseVectorStorage):
381
+ db: PostgreSQLDB = field(init=False)
382
+
383
  def __post_init__(self):
384
  self._max_batch_size = self.global_config["embedding_batch_num"]
385
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
 
391
  self.cosine_better_than_threshold = cosine_threshold
392
 
393
  async def initialize(self):
394
+ if self.db is None:
395
  self.db = await ClientManager.get_client()
396
 
397
  async def finalize(self):
398
+ if self.db is not None:
399
  await ClientManager.release_client(self.db)
400
  self.db = None
401
 
 
515
  @final
516
  @dataclass
517
  class PGDocStatusStorage(DocStatusStorage):
518
+ db: PostgreSQLDB = field(init=False)
519
+
520
  async def initialize(self):
521
+ if self.db is None:
522
  self.db = await ClientManager.get_client()
523
 
524
  async def finalize(self):
525
+ if self.db is not None:
526
  await ClientManager.release_client(self.db)
527
  self.db = None
528
 
 
665
  @final
666
  @dataclass
667
  class PGGraphStorage(BaseGraphStorage):
668
+ db: PostgreSQLDB = field(init=False)
669
+
670
  @staticmethod
671
  def load_nx_graph(file_name):
672
  print("no preloading of graph with AGE in production")
 
678
  }
679
 
680
  async def initialize(self):
681
+ if self.db is None:
682
  self.db = await ClientManager.get_client()
683
 
684
  async def finalize(self):
685
+ if self.db is not None:
686
  await ClientManager.release_client(self.db)
687
  self.db = None
688
 
lightrag/kg/tidb_impl.py CHANGED
@@ -1,6 +1,6 @@
1
  import asyncio
2
  import os
3
- from dataclasses import dataclass
4
  from typing import Any, Union, final
5
 
6
  import numpy as np
@@ -166,19 +166,18 @@ class ClientManager:
166
  @final
167
  @dataclass
168
  class TiDBKVStorage(BaseKVStorage):
169
- # db instance must be injected before use
170
- # db: TiDB
171
 
172
  def __post_init__(self):
173
  self._data = {}
174
  self._max_batch_size = self.global_config["embedding_batch_num"]
175
 
176
  async def initialize(self):
177
- if not hasattr(self, "db") or self.db is None:
178
  self.db = await ClientManager.get_client()
179
 
180
  async def finalize(self):
181
- if hasattr(self, "db") and self.db is not None:
182
  await ClientManager.release_client(self.db)
183
  self.db = None
184
 
@@ -280,6 +279,8 @@ class TiDBKVStorage(BaseKVStorage):
280
  @final
281
  @dataclass
282
  class TiDBVectorDBStorage(BaseVectorStorage):
 
 
283
  def __post_init__(self):
284
  self._client_file_name = os.path.join(
285
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -294,11 +295,11 @@ class TiDBVectorDBStorage(BaseVectorStorage):
294
  self.cosine_better_than_threshold = cosine_threshold
295
 
296
  async def initialize(self):
297
- if not hasattr(self, "db") or self.db is None:
298
  self.db = await ClientManager.get_client()
299
 
300
  async def finalize(self):
301
- if hasattr(self, "db") and self.db is not None:
302
  await ClientManager.release_client(self.db)
303
  self.db = None
304
 
@@ -421,18 +422,17 @@ class TiDBVectorDBStorage(BaseVectorStorage):
421
  @final
422
  @dataclass
423
  class TiDBGraphStorage(BaseGraphStorage):
424
- # db instance must be injected before use
425
- # db: TiDB
426
 
427
  def __post_init__(self):
428
  self._max_batch_size = self.global_config["embedding_batch_num"]
429
 
430
  async def initialize(self):
431
- if not hasattr(self, "db") or self.db is None:
432
  self.db = await ClientManager.get_client()
433
 
434
  async def finalize(self):
435
- if hasattr(self, "db") and self.db is not None:
436
  await ClientManager.release_client(self.db)
437
  self.db = None
438
 
 
1
  import asyncio
2
  import os
3
+ from dataclasses import dataclass, field
4
  from typing import Any, Union, final
5
 
6
  import numpy as np
 
166
  @final
167
  @dataclass
168
  class TiDBKVStorage(BaseKVStorage):
169
+ db: TiDB = field(init=False)
 
170
 
171
  def __post_init__(self):
172
  self._data = {}
173
  self._max_batch_size = self.global_config["embedding_batch_num"]
174
 
175
  async def initialize(self):
176
+ if self.db is None:
177
  self.db = await ClientManager.get_client()
178
 
179
  async def finalize(self):
180
+ if self.db is not None:
181
  await ClientManager.release_client(self.db)
182
  self.db = None
183
 
 
279
  @final
280
  @dataclass
281
  class TiDBVectorDBStorage(BaseVectorStorage):
282
+ db: TiDB = field(init=False)
283
+
284
  def __post_init__(self):
285
  self._client_file_name = os.path.join(
286
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
 
295
  self.cosine_better_than_threshold = cosine_threshold
296
 
297
  async def initialize(self):
298
+ if self.db is None:
299
  self.db = await ClientManager.get_client()
300
 
301
  async def finalize(self):
302
+ if self.db is not None:
303
  await ClientManager.release_client(self.db)
304
  self.db = None
305
 
 
422
  @final
423
  @dataclass
424
  class TiDBGraphStorage(BaseGraphStorage):
425
+ db: TiDB = field(init=False)
 
426
 
427
  def __post_init__(self):
428
  self._max_batch_size = self.global_config["embedding_batch_num"]
429
 
430
  async def initialize(self):
431
+ if self.db is None:
432
  self.db = await ClientManager.get_client()
433
 
434
  async def finalize(self):
435
+ if self.db is not None:
436
  await ClientManager.release_client(self.db)
437
  self.db = None
438