YanSte committed on
Commit 6cc34aa · 1 Parent(s): 4ff677c

cleaned import

lightrag/base.py CHANGED
@@ -1,6 +1,8 @@
+from enum import Enum
 import os
 from dataclasses import dataclass, field
 from typing import (
+    Optional,
     TypedDict,
     Union,
     Literal,
@@ -8,6 +10,8 @@ from typing import (
     Any,
 )
 
+import numpy as np
+
 
 from .utils import EmbeddingFunc
 
@@ -99,9 +103,7 @@ class BaseKVStorage(StorageNameSpace):
     async def drop(self) -> None:
         raise NotImplementedError
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         raise NotImplementedError
 
 
@@ -148,12 +150,12 @@ class BaseGraphStorage(StorageNameSpace):
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
 
-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         raise NotImplementedError
 
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
-    ) -> Dict[str, List[Dict]]:
+    ) -> dict[str, list[dict]]:
         raise NotImplementedError
 
 
@@ -177,20 +179,20 @@ class DocProcessingStatus:
     updated_at: str  # ISO format timestamp
     chunks_count: Optional[int] = None  # Number of chunks after splitting
     error: Optional[str] = None  # Error message if failed
-    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+    metadata: dict[str, Any] = field(default_factory=dict)  # Additional metadata
 
 
 class DocStatusStorage(BaseKVStorage):
     """Base class for document status storage"""
 
-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         raise NotImplementedError
 
-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
         raise NotImplementedError
 
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
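The abstract methods above only define the contract; the backends below fill it in. A minimal, self-contained sketch of that contract (ToyDocStatusStore and its _data layout are hypothetical and exist only to illustrate the expected behaviour, they are not part of LightRAG):

from collections import Counter
from typing import Any, Union


class ToyDocStatusStore:
    """Hypothetical in-memory stand-in for a DocStatusStorage backend."""

    def __init__(self) -> None:
        # doc_id -> record; every record carries a "status" field, as in the real stores
        self._data: dict[str, dict[str, Any]] = {}

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        # Same convention as JsonKVStorage below: matching rows, or None when nothing matches
        result = [v for v in self._data.values() if v.get("status") == status]
        return result if result else None

    async def get_status_counts(self) -> dict[str, int]:
        # Number of documents per status value
        return dict(Counter(v.get("status") for v in self._data.values()))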
lightrag/kg/json_kv_impl.py CHANGED
@@ -51,8 +51,6 @@ class JsonKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         self._data = {}
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         result = [v for _, v in self._data.items() if v["status"] == status]
         return result if result else None
lightrag/kg/mongo_impl.py CHANGED
@@ -77,9 +77,7 @@ class MongoKVStorage(BaseKVStorage):
         """Drop the collection"""
         await self._data.drop()
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Get documents by status and ids"""
         return self._data.find({"status": status})
 
lightrag/kg/oracle_impl.py CHANGED
@@ -229,9 +229,7 @@ class OracleKVStorage(BaseKVStorage):
         res = [{k: v} for k, v in dict_res.items()]
         return res
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
lightrag/kg/postgres_impl.py CHANGED
@@ -231,9 +231,7 @@ class PGKVStorage(BaseKVStorage):
         else:
             return await self.db.query(sql, params, multirows=True)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
lightrag/kg/redis_impl.py CHANGED
@@ -59,9 +59,7 @@ class RedisKVStorage(BaseKVStorage):
         if keys:
             await self._redis.delete(*keys)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         pipe = self._redis.pipeline()
         for key in await self._redis.keys(f"{self.namespace}:*"):
             pipe.hgetall(key)
lightrag/kg/tidb_impl.py CHANGED
@@ -322,9 +322,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         merge_sql = SQL_TEMPLATES["insert_relationship"]
         await self.db.execute(merge_sql, data)
 
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
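After this commit every backend exposes the same one-line get_by_status signature and returns None when nothing matches, so callers can treat the results uniformly. A sketch of the calling pattern the pipeline code below relies on (storage stands for any of the KV implementations; collect_unfinished is an illustrative helper, not a LightRAG function):

from lightrag.base import DocStatus


async def collect_unfinished(storage) -> list[str]:
    """Gather ids of failed and pending documents, tolerating a None result."""
    doc_keys: list[str] = []
    for status in (DocStatus.FAILED, DocStatus.PENDING):
        docs = await storage.get_by_status(status=status)
        if docs:
            doc_keys.extend(doc["id"] for doc in docs)
    return doc_keys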
lightrag/lightrag.py CHANGED
@@ -4,11 +4,16 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Type, Union
+from typing import Any, Type, Union, cast
 import traceback
 from .operate import (
     chunking_by_token_size,
-    extract_entities
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
     # local_query,global_query,hybrid_query,,
 )
 
@@ -19,18 +24,21 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
     DocStatus,
+    QueryParam,
+    StorageNameSpace,
 )
 
 from .namespace import NameSpace, make_namespace
 
 from .prompt import GRAPH_FIELD_SEP
+
 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
     "JsonKVStorage": ".kg.json_kv_impl",
@@ -351,9 +359,10 @@ class LightRAG:
         )
 
     async def ainsert(
-        self, string_or_strings: Union[str, list[str]],
-        split_by_character: str | None = None,
-        split_by_character_only: bool = False
+        self,
+        string_or_strings: Union[str, list[str]],
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
     ):
         """Insert documents with checkpoint support
 
@@ -368,7 +377,6 @@
         await self.apipeline_process_chunks(split_by_character, split_by_character_only)
         await self.apipeline_process_extract_graph()
 
-
     def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
@@ -482,31 +490,27 @@
         logger.info(f"Stored {len(new_docs)} new unique documents")
 
     async def apipeline_process_chunks(
-        self,
-        split_by_character: str | None = None,
-        split_by_character_only: bool = False
-    ) -> None:
+        self,
+        split_by_character: str | None = None,
+        split_by_character_only: bool = False,
+    ) -> None:
         """Get pendding documents, split into chunks,insert chunks"""
         # 1. get all pending and failed documents
         to_process_doc_keys: list[str] = []
 
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
+
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
 
         if not to_process_doc_keys:
             logger.info("All documents have been processed or are duplicates")
-            return
+            return
 
         full_docs_ids = await self.full_docs.get_by_ids(to_process_doc_keys)
         new_docs = {}
@@ -515,8 +519,8 @@
 
         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
-            return
-
+            return
+
         # 2. split docs into chunks, insert chunks, update doc status
         batch_size = self.addon_params.get("insert_batch_size", 10)
         for i in range(0, len(new_docs), batch_size):
@@ -526,11 +530,11 @@
                 batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 doc_status: dict[str, Any] = {
-                    "content_summary": doc["content_summary"],
-                    "content_length": doc["content_length"],
-                    "status": DocStatus.PROCESSING,
-                    "created_at": doc["created_at"],
-                    "updated_at": datetime.now().isoformat(),
+                    "content_summary": doc["content_summary"],
+                    "content_length": doc["content_length"],
+                    "status": DocStatus.PROCESSING,
+                    "created_at": doc["created_at"],
+                    "updated_at": datetime.now().isoformat(),
                 }
                 try:
                     await self.doc_status.upsert({doc_id: doc_status})
@@ -564,14 +568,16 @@
 
                 except Exception as e:
                     doc_status.update(
-                        {
-                            "status": DocStatus.FAILED,
-                            "error": str(e),
-                            "updated_at": datetime.now().isoformat(),
-                        }
-                    )
+                        {
+                            "status": DocStatus.FAILED,
+                            "error": str(e),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
                     await self.doc_status.upsert({doc_id: doc_status})
-                    logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
+                    logger.error(
+                        f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    )
                     continue
 
     async def apipeline_process_extract_graph(self):
@@ -580,22 +586,18 @@
         to_process_doc_keys: list[str] = []
 
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
-
+
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
 
         if not to_process_doc_keys:
             logger.info("All documents have been processed or are duplicates")
-            return
+            return
 
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
@@ -606,7 +608,7 @@
 
         async def process_chunk(chunk_id: str):
            async with semaphore:
-                chunks:dict[str, Any] = {
+                chunks: dict[str, Any] = {
                     i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
                 }
                 # Extract and store entities and relationships
@@ -1051,7 +1053,7 @@
             return content
         return content[:max_length] + "..."
 
-    async def get_processing_status(self) -> Dict[str, int]:
+    async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts
 
         Returns:
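The reformatted ainsert signature keeps the two chunk-splitting options as keyword arguments. A usage sketch, assuming a LightRAG instance configured elsewhere with its LLM and embedding functions (all values shown are placeholders):

import asyncio

from lightrag import LightRAG


async def main() -> None:
    # Placeholder configuration; real setups also pass llm_model_func / embedding_func
    rag = LightRAG(working_dir="./rag_storage")
    await rag.ainsert(
        ["First document ...", "Second document ..."],
        split_by_character="\n\n",      # optional: split chunks on this character first
        split_by_character_only=False,  # False: oversized chunks are still split by token size
    )


asyncio.run(main())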