ArnoChen committed on
Commit
d0b5505
·
1 Parent(s): 858f35e

implement MongoDB support for VectorDB storage. optimize existing MongoDB implementations

Browse files
lightrag/api/README.md CHANGED
@@ -177,7 +177,8 @@ TiDBVectorDBStorage TiDB
177
  PGVectorStorage Postgres
178
  FaissVectorDBStorage Faiss
179
  QdrantVectorDBStorage Qdrant
180
- OracleVectorDBStorag Oracle
 
181
  ```
182
 
183
  * DOC_STATUS_STORAGE:supported implement-name
 
177
  PGVectorStorage Postgres
178
  FaissVectorDBStorage Faiss
179
  QdrantVectorDBStorage Qdrant
180
+ OracleVectorDBStorage Oracle
181
+ MongoVectorDBStorage MongoDB
182
  ```
183
 
184
  * DOC_STATUS_STORAGE:supported implement-name
lightrag/kg/mongo_impl.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
 
7
 
8
  if not pm.is_installed("pymongo"):
9
  pm.install("pymongo")
@@ -14,16 +15,20 @@ if not pm.is_installed("motor"):
14
  from typing import Any, List, Tuple, Union
15
  from motor.motor_asyncio import AsyncIOMotorClient
16
  from pymongo import MongoClient
 
 
17
 
18
  from ..base import (
19
  BaseGraphStorage,
20
  BaseKVStorage,
 
21
  DocProcessingStatus,
22
  DocStatus,
23
  DocStatusStorage,
24
  )
25
  from ..namespace import NameSpace, is_namespace
26
  from ..utils import logger
 
27
 
28
 
29
  config = configparser.ConfigParser()
@@ -33,56 +38,66 @@ config.read("config.ini", "utf-8")
33
  @dataclass
34
  class MongoKVStorage(BaseKVStorage):
35
  def __post_init__(self):
36
- client = MongoClient(
37
- os.environ.get(
38
- "MONGO_URI",
39
- config.get(
40
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
41
- ),
42
- )
43
  )
 
44
  database = client.get_database(
45
  os.environ.get(
46
  "MONGO_DATABASE",
47
  config.get("mongodb", "database", fallback="LightRAG"),
48
  )
49
  )
50
- self._data = database.get_collection(self.namespace)
51
- logger.info(f"Use MongoDB as KV {self.namespace}")
 
 
 
 
 
 
52
 
53
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
54
- return self._data.find_one({"_id": id})
55
 
56
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
57
- return list(self._data.find({"_id": {"$in": ids}}))
 
58
 
59
  async def filter_keys(self, data: set[str]) -> set[str]:
60
- existing_ids = [
61
- str(x["_id"])
62
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
63
- ]
64
- return set([s for s in data if s not in existing_ids])
65
 
66
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
67
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
 
68
  for mode, items in data.items():
69
- for k, v in tqdm_async(items.items(), desc="Upserting"):
70
  key = f"{mode}_{k}"
71
- result = self._data.update_one(
72
- {"_id": key}, {"$setOnInsert": v}, upsert=True
 
 
 
73
  )
74
- if result.upserted_id:
75
- logger.debug(f"\nInserted new document with key: {key}")
76
- data[mode][k]["_id"] = key
77
  else:
78
- for k, v in tqdm_async(data.items(), desc="Upserting"):
79
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
80
  data[k]["_id"] = k
 
 
 
 
81
 
82
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
83
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
84
  res = {}
85
- v = self._data.find_one({"_id": mode + "_" + id})
86
  if v:
87
  res[id] = v
88
  logger.debug(f"llm_response_cache find one by:{id}")
@@ -100,30 +115,48 @@ class MongoKVStorage(BaseKVStorage):
100
  @dataclass
101
  class MongoDocStatusStorage(DocStatusStorage):
102
  def __post_init__(self):
103
- client = MongoClient(
104
- os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
 
 
 
 
 
 
 
 
 
 
105
  )
106
- database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG"))
107
- self._data = database.get_collection(self.namespace)
108
- logger.info(f"Use MongoDB as doc status {self.namespace}")
 
 
 
 
 
109
 
110
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
111
- return self._data.find_one({"_id": id})
112
 
113
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
114
- return list(self._data.find({"_id": {"$in": ids}}))
 
115
 
116
  async def filter_keys(self, data: set[str]) -> set[str]:
117
- existing_ids = [
118
- str(x["_id"])
119
- for x in self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
120
- ]
121
- return set([s for s in data if s not in existing_ids])
122
 
123
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
124
  for k, v in data.items():
125
- self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
126
  data[k]["_id"] = k
 
 
 
 
127
 
128
  async def drop(self) -> None:
129
  """Drop the collection"""
@@ -132,7 +165,8 @@ class MongoDocStatusStorage(DocStatusStorage):
132
  async def get_status_counts(self) -> dict[str, int]:
133
  """Get counts of documents in each status"""
134
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
135
- result = list(self._data.aggregate(pipeline))
 
136
  counts = {}
137
  for doc in result:
138
  counts[doc["_id"]] = doc["count"]
@@ -142,7 +176,8 @@ class MongoDocStatusStorage(DocStatusStorage):
142
  self, status: DocStatus
143
  ) -> dict[str, DocProcessingStatus]:
144
  """Get all documents by status"""
145
- result = list(self._data.find({"status": status.value}))
 
146
  return {
147
  doc["_id"]: DocProcessingStatus(
148
  content=doc["content"],
@@ -185,26 +220,27 @@ class MongoGraphStorage(BaseGraphStorage):
185
  global_config=global_config,
186
  embedding_func=embedding_func,
187
  )
188
- self.client = AsyncIOMotorClient(
189
- os.environ.get(
190
- "MONGO_URI",
191
- config.get(
192
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
193
- ),
194
- )
195
  )
196
- self.db = self.client[
 
197
  os.environ.get(
198
  "MONGO_DATABASE",
199
- mongo_database=config.get("mongodb", "database", fallback="LightRAG"),
200
- )
201
- ]
202
- self.collection = self.db[
203
- os.environ.get(
204
- "MONGO_KG_COLLECTION",
205
- config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"),
206
  )
207
- ]
 
 
 
 
 
 
 
 
208
 
209
  #
210
  # -------------------------------------------------------------------------
@@ -451,7 +487,7 @@ class MongoGraphStorage(BaseGraphStorage):
451
  self, source_node_id: str
452
  ) -> Union[List[Tuple[str, str]], None]:
453
  """
454
- Return a list of (target_id, relation) for direct edges from source_node_id.
455
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
456
  """
457
  pipeline = [
@@ -475,7 +511,7 @@ class MongoGraphStorage(BaseGraphStorage):
475
  return None
476
 
477
  edges = result[0].get("edges", [])
478
- return [(e["target"], e["relation"]) for e in edges]
479
 
480
  #
481
  # -------------------------------------------------------------------------
@@ -522,7 +558,7 @@ class MongoGraphStorage(BaseGraphStorage):
522
 
523
  async def delete_node(self, node_id: str):
524
  """
525
- 1) Remove nodes doc entirely.
526
  2) Remove inbound edges from any doc that references node_id.
527
  """
528
  # Remove inbound edges from all other docs
@@ -542,3 +578,359 @@ class MongoGraphStorage(BaseGraphStorage):
542
  Placeholder for demonstration, raises NotImplementedError.
543
  """
544
  raise NotImplementedError("Node embedding is not used in lightrag.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pipmaster as pm
5
  import configparser
6
  from tqdm.asyncio import tqdm as tqdm_async
7
+ import asyncio
8
 
9
  if not pm.is_installed("pymongo"):
10
  pm.install("pymongo")
 
15
  from typing import Any, List, Tuple, Union
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from pymongo import MongoClient
18
+ from pymongo.operations import SearchIndexModel
19
+ from pymongo.errors import PyMongoError
20
 
21
  from ..base import (
22
  BaseGraphStorage,
23
  BaseKVStorage,
24
+ BaseVectorStorage,
25
  DocProcessingStatus,
26
  DocStatus,
27
  DocStatusStorage,
28
  )
29
  from ..namespace import NameSpace, is_namespace
30
  from ..utils import logger
31
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
32
 
33
 
34
  config = configparser.ConfigParser()
 
38
  @dataclass
39
  class MongoKVStorage(BaseKVStorage):
40
  def __post_init__(self):
41
+ uri = os.environ.get(
42
+ "MONGO_URI",
43
+ config.get(
44
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
45
+ ),
 
 
46
  )
47
+ client = AsyncIOMotorClient(uri)
48
  database = client.get_database(
49
  os.environ.get(
50
  "MONGO_DATABASE",
51
  config.get("mongodb", "database", fallback="LightRAG"),
52
  )
53
  )
54
+
55
+ self._collection_name = self.namespace
56
+
57
+ self._data = database.get_collection(self._collection_name)
58
+ logger.debug(f"Use MongoDB as KV {self._collection_name}")
59
+
60
+ # Ensure collection exists
61
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
62
 
63
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
64
+ return await self._data.find_one({"_id": id})
65
 
66
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
67
+ cursor = self._data.find({"_id": {"$in": ids}})
68
+ return await cursor.to_list()
69
 
70
  async def filter_keys(self, data: set[str]) -> set[str]:
71
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
72
+ existing_ids = {str(x["_id"]) async for x in cursor}
73
+ return data - existing_ids
 
 
74
 
75
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
76
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
77
+ update_tasks = []
78
  for mode, items in data.items():
79
+ for k, v in items.items():
80
  key = f"{mode}_{k}"
81
+ data[mode][k]["_id"] = f"{mode}_{k}"
82
+ update_tasks.append(
83
+ self._data.update_one(
84
+ {"_id": key}, {"$setOnInsert": v}, upsert=True
85
+ )
86
  )
87
+ await asyncio.gather(*update_tasks)
 
 
88
  else:
89
+ update_tasks = []
90
+ for k, v in data.items():
91
  data[k]["_id"] = k
92
+ update_tasks.append(
93
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
94
+ )
95
+ await asyncio.gather(*update_tasks)
96
 
97
  async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
98
  if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
99
  res = {}
100
+ v = await self._data.find_one({"_id": mode + "_" + id})
101
  if v:
102
  res[id] = v
103
  logger.debug(f"llm_response_cache find one by:{id}")
 
115
  @dataclass
116
  class MongoDocStatusStorage(DocStatusStorage):
117
  def __post_init__(self):
118
+ uri = os.environ.get(
119
+ "MONGO_URI",
120
+ config.get(
121
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
122
+ ),
123
+ )
124
+ client = AsyncIOMotorClient(uri)
125
+ database = client.get_database(
126
+ os.environ.get(
127
+ "MONGO_DATABASE",
128
+ config.get("mongodb", "database", fallback="LightRAG"),
129
+ )
130
  )
131
+
132
+ self._collection_name = self.namespace
133
+ self._data = database.get_collection(self._collection_name)
134
+
135
+ logger.debug(f"Use MongoDB as doc status {self._collection_name}")
136
+
137
+ # Ensure collection exists
138
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
139
 
140
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
141
+ return await self._data.find_one({"_id": id})
142
 
143
  async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
144
+ cursor = self._data.find({"_id": {"$in": ids}})
145
+ return await cursor.to_list()
146
 
147
  async def filter_keys(self, data: set[str]) -> set[str]:
148
+ cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1})
149
+ existing_ids = {str(x["_id"]) async for x in cursor}
150
+ return data - existing_ids
 
 
151
 
152
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
153
+ update_tasks = []
154
  for k, v in data.items():
 
155
  data[k]["_id"] = k
156
+ update_tasks.append(
157
+ self._data.update_one({"_id": k}, {"$set": v}, upsert=True)
158
+ )
159
+ await asyncio.gather(*update_tasks)
160
 
161
  async def drop(self) -> None:
162
  """Drop the collection"""
 
165
  async def get_status_counts(self) -> dict[str, int]:
166
  """Get counts of documents in each status"""
167
  pipeline = [{"$group": {"_id": "$status", "count": {"$sum": 1}}}]
168
+ cursor = self._data.aggregate(pipeline)
169
+ result = await cursor.to_list()
170
  counts = {}
171
  for doc in result:
172
  counts[doc["_id"]] = doc["count"]
 
176
  self, status: DocStatus
177
  ) -> dict[str, DocProcessingStatus]:
178
  """Get all documents by status"""
179
+ cursor = self._data.find({"status": status.value})
180
+ result = await cursor.to_list()
181
  return {
182
  doc["_id"]: DocProcessingStatus(
183
  content=doc["content"],
 
220
  global_config=global_config,
221
  embedding_func=embedding_func,
222
  )
223
+ uri = os.environ.get(
224
+ "MONGO_URI",
225
+ config.get(
226
+ "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
227
+ ),
 
 
228
  )
229
+ client = AsyncIOMotorClient(uri)
230
+ database = client.get_database(
231
  os.environ.get(
232
  "MONGO_DATABASE",
233
+ config.get("mongodb", "database", fallback="LightRAG"),
 
 
 
 
 
 
234
  )
235
+ )
236
+
237
+ self._collection_name = self.namespace
238
+ self.collection = database.get_collection(self._collection_name)
239
+
240
+ logger.debug(f"Use MongoDB as KG {self._collection_name}")
241
+
242
+ # Ensure collection exists
243
+ create_collection_if_not_exists(uri, database.name, self._collection_name)
244
 
245
  #
246
  # -------------------------------------------------------------------------
 
487
  self, source_node_id: str
488
  ) -> Union[List[Tuple[str, str]], None]:
489
  """
490
+ Return a list of (source_id, target_id) for direct edges from source_node_id.
491
  Demonstrates $graphLookup at maxDepth=0, though direct doc retrieval is simpler.
492
  """
493
  pipeline = [
 
511
  return None
512
 
513
  edges = result[0].get("edges", [])
514
+ return [(source_node_id, e["target"]) for e in edges]
515
 
516
  #
517
  # -------------------------------------------------------------------------
 
558
 
559
  async def delete_node(self, node_id: str):
560
  """
561
+ 1) Remove node's doc entirely.
562
  2) Remove inbound edges from any doc that references node_id.
563
  """
564
  # Remove inbound edges from all other docs
 
578
  Placeholder for demonstration, raises NotImplementedError.
579
  """
580
  raise NotImplementedError("Node embedding is not used in lightrag.")
581
+
582
+ #
583
+ # -------------------------------------------------------------------------
584
+ # QUERY
585
+ # -------------------------------------------------------------------------
586
+ #
587
+
588
+ async def get_all_labels(self) -> list[str]:
589
+ """
590
+ Get all existing node _id in the database
591
+ Returns:
592
+ [id1, id2, ...] # Alphabetically sorted id list
593
+ """
594
+ # Use MongoDB's distinct and aggregation to get all unique labels
595
+ pipeline = [
596
+ {"$group": {"_id": "$_id"}}, # Group by _id
597
+ {"$sort": {"_id": 1}}, # Sort alphabetically
598
+ ]
599
+
600
+ cursor = self.collection.aggregate(pipeline)
601
+ labels = []
602
+ async for doc in cursor:
603
+ labels.append(doc["_id"])
604
+ return labels
605
+
606
+ async def get_knowledge_graph(
607
+ self, node_label: str, max_depth: int = 5
608
+ ) -> KnowledgeGraph:
609
+ """
610
+ Get complete connected subgraph for specified node (including the starting node itself)
611
+
612
+ Args:
613
+ node_label: Label of the nodes to start from
614
+ max_depth: Maximum depth of traversal (default: 5)
615
+
616
+ Returns:
617
+ KnowledgeGraph object containing nodes and edges of the subgraph
618
+ """
619
+ label = node_label
620
+ result = KnowledgeGraph()
621
+ seen_nodes = set()
622
+ seen_edges = set()
623
+
624
+ try:
625
+ if label == "*":
626
+ # Get all nodes and edges
627
+ async for node_doc in self.collection.find({}):
628
+ node_id = str(node_doc["_id"])
629
+ if node_id not in seen_nodes:
630
+ result.nodes.append(
631
+ KnowledgeGraphNode(
632
+ id=node_id,
633
+ labels=[node_doc.get("_id")],
634
+ properties={
635
+ k: v
636
+ for k, v in node_doc.items()
637
+ if k not in ["_id", "edges"]
638
+ },
639
+ )
640
+ )
641
+ seen_nodes.add(node_id)
642
+
643
+ # Process edges
644
+ for edge in node_doc.get("edges", []):
645
+ edge_id = f"{node_id}-{edge['target']}"
646
+ if edge_id not in seen_edges:
647
+ result.edges.append(
648
+ KnowledgeGraphEdge(
649
+ id=edge_id,
650
+ type=edge.get("relation", ""),
651
+ source=node_id,
652
+ target=edge["target"],
653
+ properties={
654
+ k: v
655
+ for k, v in edge.items()
656
+ if k not in ["target", "relation"]
657
+ },
658
+ )
659
+ )
660
+ seen_edges.add(edge_id)
661
+ else:
662
+ # Verify if starting node exists
663
+ start_nodes = self.collection.find({"_id": label})
664
+ start_nodes_exist = await start_nodes.to_list(length=1)
665
+ if not start_nodes_exist:
666
+ logger.warning(f"Starting node with label {label} does not exist!")
667
+ return result
668
+
669
+ # Use $graphLookup for traversal
670
+ pipeline = [
671
+ {
672
+ "$match": {"_id": label}
673
+ }, # Start with nodes having the specified label
674
+ {
675
+ "$graphLookup": {
676
+ "from": self._collection_name,
677
+ "startWith": "$edges.target",
678
+ "connectFromField": "edges.target",
679
+ "connectToField": "_id",
680
+ "maxDepth": max_depth,
681
+ "depthField": "depth",
682
+ "as": "connected_nodes",
683
+ }
684
+ },
685
+ ]
686
+
687
+ async for doc in self.collection.aggregate(pipeline):
688
+ # Add the start node
689
+ node_id = str(doc["_id"])
690
+ if node_id not in seen_nodes:
691
+ result.nodes.append(
692
+ KnowledgeGraphNode(
693
+ id=node_id,
694
+ labels=[
695
+ doc.get(
696
+ "_id",
697
+ )
698
+ ],
699
+ properties={
700
+ k: v
701
+ for k, v in doc.items()
702
+ if k
703
+ not in [
704
+ "_id",
705
+ "edges",
706
+ "connected_nodes",
707
+ "depth",
708
+ ]
709
+ },
710
+ )
711
+ )
712
+ seen_nodes.add(node_id)
713
+
714
+ # Add edges from start node
715
+ for edge in doc.get("edges", []):
716
+ edge_id = f"{node_id}-{edge['target']}"
717
+ if edge_id not in seen_edges:
718
+ result.edges.append(
719
+ KnowledgeGraphEdge(
720
+ id=edge_id,
721
+ type=edge.get("relation", ""),
722
+ source=node_id,
723
+ target=edge["target"],
724
+ properties={
725
+ k: v
726
+ for k, v in edge.items()
727
+ if k not in ["target", "relation"]
728
+ },
729
+ )
730
+ )
731
+ seen_edges.add(edge_id)
732
+
733
+ # Add connected nodes and their edges
734
+ for connected in doc.get("connected_nodes", []):
735
+ node_id = str(connected["_id"])
736
+ if node_id not in seen_nodes:
737
+ result.nodes.append(
738
+ KnowledgeGraphNode(
739
+ id=node_id,
740
+ labels=[connected.get("_id")],
741
+ properties={
742
+ k: v
743
+ for k, v in connected.items()
744
+ if k not in ["_id", "edges", "depth"]
745
+ },
746
+ )
747
+ )
748
+ seen_nodes.add(node_id)
749
+
750
+ # Add edges from connected nodes
751
+ for edge in connected.get("edges", []):
752
+ edge_id = f"{node_id}-{edge['target']}"
753
+ if edge_id not in seen_edges:
754
+ result.edges.append(
755
+ KnowledgeGraphEdge(
756
+ id=edge_id,
757
+ type=edge.get("relation", ""),
758
+ source=node_id,
759
+ target=edge["target"],
760
+ properties={
761
+ k: v
762
+ for k, v in edge.items()
763
+ if k not in ["target", "relation"]
764
+ },
765
+ )
766
+ )
767
+ seen_edges.add(edge_id)
768
+
769
+ logger.info(
770
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
771
+ )
772
+
773
+ except PyMongoError as e:
774
+ logger.error(f"MongoDB query failed: {str(e)}")
775
+
776
+ return result
777
+
778
+
779
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
    """Vector storage backed by a MongoDB collection and an Atlas Vector Search index.

    Documents are stored under their key as ``_id`` with the embedding in the
    ``vector`` field; queries run a ``$vectorSearch`` aggregation against the
    ``vector_knn_index`` index.
    """

    # Minimum search score a hit must reach to be returned by ``query``.
    # Populated in ``__post_init__`` from ``vector_db_storage_cls_kwargs``.
    cosine_better_than_threshold: float = None

    def __post_init__(self):
        """Read configuration, connect to MongoDB and ensure collection and index.

        Raises:
            ValueError: if ``cosine_better_than_threshold`` is missing from
                ``global_config["vector_db_storage_cls_kwargs"]``.
        """
        kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = kwargs.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

        uri = os.environ.get(
            "MONGO_URI",
            config.get(
                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
            ),
        )
        client = AsyncIOMotorClient(uri)
        database = client.get_database(
            os.environ.get(
                "MONGO_DATABASE",
                config.get("mongodb", "database", fallback="LightRAG"),
            )
        )

        self._collection_name = self.namespace
        self._data = database.get_collection(self._collection_name)
        self._max_batch_size = self.global_config["embedding_batch_num"]

        logger.debug(f"Use MongoDB as VDB {self._collection_name}")

        # Ensure collection exists
        create_collection_if_not_exists(uri, database.name, self._collection_name)

        # Ensure vector index exists
        self.create_vector_index(uri, database.name, self._collection_name)

    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
        """Create the Atlas Vector Search index ``vector_knn_index`` (best effort).

        Args:
            uri: MongoDB connection string for the short-lived synchronous client.
            database_name: Database holding the collection.
            collection_name: Collection to create the index on.

        A ``PyMongoError`` (for example because the index already exists) is
        logged and swallowed so that storage construction does not fail.
        """
        search_index_model = SearchIndexModel(
            definition={
                "fields": [
                    {
                        "type": "vector",
                        # Must match the dimension of the stored embeddings.
                        "numDimensions": self.embedding_func.embedding_dim,
                        "path": "vector",
                        "similarity": "cosine",  # Options: euclidean, cosine, dotProduct
                    }
                ]
            },
            name="vector_knn_index",
            type="vectorSearch",
        )

        # The context manager closes the synchronous client when done.
        with MongoClient(uri) as client:
            collection = client.get_database(database_name).get_collection(
                collection_name
            )
            try:
                collection.create_search_index(search_index_model)
                logger.info("Vector index created successfully.")
            except PyMongoError as e:
                # Keep the cause visible: "already exists" is only one of the
                # reasons index creation can fail.
                logger.debug(
                    f"Vector index was not created (it may already exist): {e}"
                )

    async def upsert(self, data: dict[str, dict]):
        """Embed the ``content`` of every entry and upsert it with its vector.

        Args:
            data: Mapping of document id to a dict that must contain
                ``"content"``; only keys listed in ``self.meta_fields`` are
                stored next to the vector.

        Returns:
            The list of documents written (each with ``_id`` and ``vector``),
            or an empty list when ``data`` is empty.
        """
        logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            logger.warning("You are inserting an empty data set to vector DB")
            return []

        list_data = [
            {
                "_id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        pbar = tqdm_async(
            total=len(batches), desc="Generating embeddings", unit="batch"
        )

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        try:
            embeddings_list = await asyncio.gather(
                *(wrapped_task(batch) for batch in batches)
            )
        finally:
            # Release the progress bar even if an embedding call fails.
            pbar.close()

        embeddings = np.concatenate(embeddings_list)
        for doc, embedding in zip(list_data, embeddings):
            # Plain Python floats: numpy arrays are not BSON-serialisable.
            doc["vector"] = np.array(embedding, dtype=np.float32).tolist()

        update_tasks = [
            self._data.update_one({"_id": doc["_id"]}, {"$set": doc}, upsert=True)
            for doc in list_data
        ]
        await asyncio.gather(*update_tasks)

        return list_data

    async def query(self, query, top_k=5):
        """Query the vector database using Atlas Vector Search.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of hits requested from ``$vectorSearch``.

        Returns:
            A list of matching documents (without the ``vector`` field), each
            extended with ``id`` (its ``_id``) and ``distance`` (its search
            score). Hits scoring below ``cosine_better_than_threshold`` are
            dropped.
        """
        # Generate the embedding
        embedding = await self.embedding_func([query])

        # Convert numpy array to a list to ensure compatibility with MongoDB
        query_vector = embedding[0].tolist()

        pipeline = [
            {
                "$vectorSearch": {
                    "index": "vector_knn_index",  # Must match the created index name
                    "path": "vector",
                    "queryVector": query_vector,
                    # The server rejects limit > numCandidates, so grow the
                    # candidate pool together with top_k.
                    "numCandidates": max(100, top_k),
                    "limit": top_k,
                }
            },
            {"$addFields": {"score": {"$meta": "vectorSearchScore"}}},
            {"$match": {"score": {"$gte": self.cosine_better_than_threshold}}},
            {"$project": {"vector": 0}},
        ]

        # Execute the aggregation pipeline
        cursor = self._data.aggregate(pipeline)
        # ``length`` is passed explicitly for Motor releases that require it.
        results = await cursor.to_list(length=None)

        # Format and return the results
        return [
            {**doc, "id": doc["_id"], "distance": doc.get("score", None)}
            for doc in results
        ]
923
+
924
+
925
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
    """Check if the collection exists; if not, create it.

    Args:
        uri: MongoDB connection string for a short-lived synchronous client.
        database_name: Database to look in.
        collection_name: Collection that should exist afterwards.
    """
    # This runs once per storage instance, so close the client on exit
    # instead of leaving a connection pool open for each call.
    with MongoClient(uri) as client:
        database = client.get_database(database_name)

        if collection_name not in database.list_collection_names():
            database.create_collection(collection_name)
            logger.info(f"Created collection: {collection_name}")
        else:
            logger.debug(f"Collection '{collection_name}' already exists.")
lightrag/lightrag.py CHANGED
@@ -76,6 +76,7 @@ STORAGE_IMPLEMENTATIONS = {
76
  "FaissVectorDBStorage",
77
  "QdrantVectorDBStorage",
78
  "OracleVectorDBStorage",
 
79
  ],
80
  "required_methods": ["query", "upsert"],
81
  },
@@ -140,6 +141,7 @@ STORAGE_ENV_REQUIREMENTS = {
140
  "ORACLE_PASSWORD",
141
  "ORACLE_CONFIG_DIR",
142
  ],
 
143
  # Document Status Storage Implementations
144
  "JsonDocStatusStorage": [],
145
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
@@ -160,6 +162,7 @@ STORAGES = {
160
  "MongoKVStorage": ".kg.mongo_impl",
161
  "MongoDocStatusStorage": ".kg.mongo_impl",
162
  "MongoGraphStorage": ".kg.mongo_impl",
 
163
  "RedisKVStorage": ".kg.redis_impl",
164
  "ChromaVectorDBStorage": ".kg.chroma_impl",
165
  "TiDBKVStorage": ".kg.tidb_impl",
 
76
  "FaissVectorDBStorage",
77
  "QdrantVectorDBStorage",
78
  "OracleVectorDBStorage",
79
+ "MongoVectorDBStorage",
80
  ],
81
  "required_methods": ["query", "upsert"],
82
  },
 
141
  "ORACLE_PASSWORD",
142
  "ORACLE_CONFIG_DIR",
143
  ],
144
+ "MongoVectorDBStorage": [],
145
  # Document Status Storage Implementations
146
  "JsonDocStatusStorage": [],
147
  "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
 
162
  "MongoKVStorage": ".kg.mongo_impl",
163
  "MongoDocStatusStorage": ".kg.mongo_impl",
164
  "MongoGraphStorage": ".kg.mongo_impl",
165
+ "MongoVectorDBStorage": ".kg.mongo_impl",
166
  "RedisKVStorage": ".kg.redis_impl",
167
  "ChromaVectorDBStorage": ".kg.chroma_impl",
168
  "TiDBKVStorage": ".kg.tidb_impl",