yangdx commited on
Commit
8aa0a5e
·
1 Parent(s): c6b21a9

refactor: make cosine similarity threshold a required config parameter

Browse files

• Remove default threshold from env var
• Add validation for missing threshold
• Move default to lightrag.py config init
• Update all vector DB implementations
• Improve threshold validation consistency

lightrag/kg/chroma_impl.py CHANGED
@@ -13,15 +13,15 @@ from lightrag.utils import logger
13
  class ChromaVectorDBStorage(BaseVectorStorage):
14
  """ChromaDB vector storage implementation."""
15
 
16
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
17
 
18
  def __post_init__(self):
19
  try:
20
- # Use global config value if specified, otherwise use default
21
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
22
- self.cosine_better_than_threshold = config.get(
23
- "cosine_better_than_threshold", self.cosine_better_than_threshold
24
- )
 
25
 
26
  user_collection_settings = config.get("collection_settings", {})
27
  # Default HNSW index settings for ChromaDB
 
13
  class ChromaVectorDBStorage(BaseVectorStorage):
14
  """ChromaDB vector storage implementation."""
15
 
16
+ cosine_better_than_threshold: float = None
17
 
18
  def __post_init__(self):
19
  try:
 
20
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
21
+ cosine_threshold = config.get("cosine_better_than_threshold")
22
+ if cosine_threshold is None:
23
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
24
+ self.cosine_better_than_threshold = cosine_threshold
25
 
26
  user_collection_settings = config.get("collection_settings", {})
27
  # Default HNSW index settings for ChromaDB
lightrag/kg/faiss_impl.py CHANGED
@@ -23,14 +23,15 @@ class FaissVectorDBStorage(BaseVectorStorage):
23
  Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
24
  """
25
 
26
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
27
 
28
  def __post_init__(self):
29
  # Grab config values if available
30
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
- self.cosine_better_than_threshold = config.get(
32
- "cosine_better_than_threshold", self.cosine_better_than_threshold
33
- )
 
34
 
35
  # Where to save index file if you want persistent storage
36
  self._faiss_index_file = os.path.join(
 
23
  Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
24
  """
25
 
26
+ cosine_better_than_threshold: float = None
27
 
28
  def __post_init__(self):
29
  # Grab config values if available
30
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
31
+ cosine_threshold = config.get("cosine_better_than_threshold")
32
+ if cosine_threshold is None:
33
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
34
+ self.cosine_better_than_threshold = cosine_threshold
35
 
36
  # Where to save index file if you want persistent storage
37
  self._faiss_index_file = os.path.join(
lightrag/kg/milvus_impl.py CHANGED
@@ -19,6 +19,8 @@ config.read("config.ini", "utf-8")
19
 
20
  @dataclass
21
  class MilvusVectorDBStorge(BaseVectorStorage):
 
 
22
  @staticmethod
23
  def create_collection_if_not_exist(
24
  client: MilvusClient, collection_name: str, **kwargs
@@ -30,6 +32,12 @@ class MilvusVectorDBStorge(BaseVectorStorage):
30
  )
31
 
32
  def __post_init__(self):
 
 
 
 
 
 
33
  self._client = MilvusClient(
34
  uri=os.environ.get(
35
  "MILVUS_URI",
@@ -103,7 +111,7 @@ class MilvusVectorDBStorge(BaseVectorStorage):
103
  data=embedding,
104
  limit=top_k,
105
  output_fields=list(self.meta_fields),
106
- search_params={"metric_type": "COSINE", "params": {"radius": 0.2}},
107
  )
108
  print(results)
109
  return [
 
19
 
20
  @dataclass
21
  class MilvusVectorDBStorge(BaseVectorStorage):
22
+ cosine_better_than_threshold: float = None
23
+
24
  @staticmethod
25
  def create_collection_if_not_exist(
26
  client: MilvusClient, collection_name: str, **kwargs
 
32
  )
33
 
34
  def __post_init__(self):
35
+ config = self.global_config.get("vector_db_storage_cls_kwargs", {})
36
+ cosine_threshold = config.get("cosine_better_than_threshold")
37
+ if cosine_threshold is None:
38
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
39
+ self.cosine_better_than_threshold = cosine_threshold
40
+
41
  self._client = MilvusClient(
42
  uri=os.environ.get(
43
  "MILVUS_URI",
 
111
  data=embedding,
112
  limit=top_k,
113
  output_fields=list(self.meta_fields),
114
+ search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}},
115
  )
116
  print(results)
117
  return [
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -73,16 +73,17 @@ from lightrag.base import (
73
 
74
  @dataclass
75
  class NanoVectorDBStorage(BaseVectorStorage):
76
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
77
 
78
  def __post_init__(self):
79
  # Initialize lock only for file operations
80
  self._save_lock = asyncio.Lock()
81
  # Use global config value if specified, otherwise use default
82
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
83
- self.cosine_better_than_threshold = config.get(
84
- "cosine_better_than_threshold", self.cosine_better_than_threshold
85
- )
 
86
 
87
  self._client_file_name = os.path.join(
88
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
 
73
 
74
  @dataclass
75
  class NanoVectorDBStorage(BaseVectorStorage):
76
+ cosine_better_than_threshold: float = None
77
 
78
  def __post_init__(self):
79
  # Initialize lock only for file operations
80
  self._save_lock = asyncio.Lock()
81
  # Use global config value if specified, otherwise use default
82
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
83
+ cosine_threshold = config.get("cosine_better_than_threshold")
84
+ if cosine_threshold is None:
85
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
86
+ self.cosine_better_than_threshold = cosine_threshold
87
 
88
  self._client_file_name = os.path.join(
89
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
lightrag/kg/oracle_impl.py CHANGED
@@ -320,14 +320,14 @@ class OracleKVStorage(BaseKVStorage):
320
  class OracleVectorDBStorage(BaseVectorStorage):
321
  # db instance must be injected before use
322
  # db: OracleDB
323
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
324
 
325
  def __post_init__(self):
326
- # Use global config value if specified, otherwise use default
327
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
328
- self.cosine_better_than_threshold = config.get(
329
- "cosine_better_than_threshold", self.cosine_better_than_threshold
330
- )
 
331
 
332
  async def upsert(self, data: dict[str, dict]):
333
  """向向量数据库中插入数据"""
 
320
  class OracleVectorDBStorage(BaseVectorStorage):
321
  # db instance must be injected before use
322
  # db: OracleDB
323
+ cosine_better_than_threshold: float = None
324
 
325
  def __post_init__(self):
 
326
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
327
+ cosine_threshold = config.get("cosine_better_than_threshold")
328
+ if cosine_threshold is None:
329
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
330
+ self.cosine_better_than_threshold = cosine_threshold
331
 
332
  async def upsert(self, data: dict[str, dict]):
333
  """向向量数据库中插入数据"""
lightrag/kg/postgres_impl.py CHANGED
@@ -299,15 +299,15 @@ class PGKVStorage(BaseKVStorage):
299
  class PGVectorStorage(BaseVectorStorage):
300
  # db instance must be injected before use
301
  # db: PostgreSQLDB
302
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
303
 
304
  def __post_init__(self):
305
  self._max_batch_size = self.global_config["embedding_batch_num"]
306
- # Use global config value if specified, otherwise use default
307
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
308
- self.cosine_better_than_threshold = config.get(
309
- "cosine_better_than_threshold", self.cosine_better_than_threshold
310
- )
 
311
 
312
  def _upsert_chunks(self, item: dict):
313
  try:
 
299
  class PGVectorStorage(BaseVectorStorage):
300
  # db instance must be injected before use
301
  # db: PostgreSQLDB
302
+ cosine_better_than_threshold: float = None
303
 
304
  def __post_init__(self):
305
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
306
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
307
+ cosine_threshold = config.get("cosine_better_than_threshold")
308
+ if cosine_threshold is None:
309
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
310
+ self.cosine_better_than_threshold = cosine_threshold
311
 
312
  def _upsert_chunks(self, item: dict):
313
  try:
lightrag/kg/qdrant_impl.py CHANGED
@@ -50,6 +50,8 @@ def compute_mdhash_id_for_qdrant(
50
 
51
  @dataclass
52
  class QdrantVectorDBStorage(BaseVectorStorage):
 
 
53
  @staticmethod
54
  def create_collection_if_not_exist(
55
  client: QdrantClient, collection_name: str, **kwargs
@@ -59,6 +61,12 @@ class QdrantVectorDBStorage(BaseVectorStorage):
59
  client.create_collection(collection_name, **kwargs)
60
 
61
  def __post_init__(self):
 
 
 
 
 
 
62
  self._client = QdrantClient(
63
  url=os.environ.get(
64
  "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
@@ -131,4 +139,6 @@ class QdrantVectorDBStorage(BaseVectorStorage):
131
  with_payload=True,
132
  )
133
  logger.debug(f"query result: {results}")
134
- return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
 
 
 
50
 
51
  @dataclass
52
  class QdrantVectorDBStorage(BaseVectorStorage):
53
+ cosine_better_than_threshold: float = None
54
+
55
  @staticmethod
56
  def create_collection_if_not_exist(
57
  client: QdrantClient, collection_name: str, **kwargs
 
61
  client.create_collection(collection_name, **kwargs)
62
 
63
  def __post_init__(self):
64
+ config = self.global_config.get("vector_db_storage_cls_kwargs", {})
65
+ cosine_threshold = config.get("cosine_better_than_threshold")
66
+ if cosine_threshold is None:
67
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
68
+ self.cosine_better_than_threshold = cosine_threshold
69
+
70
  self._client = QdrantClient(
71
  url=os.environ.get(
72
  "QDRANT_URL", config.get("qdrant", "uri", fallback=None)
 
139
  with_payload=True,
140
  )
141
  logger.debug(f"query result: {results}")
142
+ # 添加余弦相似度过滤
143
+ filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold]
144
+ return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results]
lightrag/kg/tidb_impl.py CHANGED
@@ -212,18 +212,18 @@ class TiDBKVStorage(BaseKVStorage):
212
  class TiDBVectorDBStorage(BaseVectorStorage):
213
  # db instance must be injected before use
214
  # db: TiDB
215
- cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
216
 
217
  def __post_init__(self):
218
  self._client_file_name = os.path.join(
219
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
220
  )
221
  self._max_batch_size = self.global_config["embedding_batch_num"]
222
- # Use global config value if specified, otherwise use default
223
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
224
- self.cosine_better_than_threshold = config.get(
225
- "cosine_better_than_threshold", self.cosine_better_than_threshold
226
- )
 
227
 
228
  async def query(self, query: str, top_k: int) -> list[dict]:
229
  """Search from tidb vector"""
 
212
  class TiDBVectorDBStorage(BaseVectorStorage):
213
  # db instance must be injected before use
214
  # db: TiDB
215
+ cosine_better_than_threshold: float = None
216
 
217
  def __post_init__(self):
218
  self._client_file_name = os.path.join(
219
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
220
  )
221
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
222
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
223
+ cosine_threshold = config.get("cosine_better_than_threshold")
224
+ if cosine_threshold is None:
225
+ raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs")
226
+ self.cosine_better_than_threshold = cosine_threshold
227
 
228
  async def query(self, query: str, top_k: int) -> list[dict]:
229
  """Search from tidb vector"""
lightrag/lightrag.py CHANGED
@@ -420,6 +420,15 @@ class LightRAG:
420
  # Check environment variables
421
  self.check_storage_env_vars(storage_name)
422
 
 
 
 
 
 
 
 
 
 
423
  # show config
424
  global_config = asdict(self)
425
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
 
420
  # Check environment variables
421
  self.check_storage_env_vars(storage_name)
422
 
423
+ # Ensure vector_db_storage_cls_kwargs has required fields
424
+ default_vector_db_kwargs = {
425
+ "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
426
+ }
427
+ self.vector_db_storage_cls_kwargs = {
428
+ **default_vector_db_kwargs,
429
+ **self.vector_db_storage_cls_kwargs
430
+ }
431
+
432
  # show config
433
  global_config = asdict(self)
434
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])