gzdaniel committed
Commit f29caf0 · 1 Parent(s): 6320c9d

Fix linting

lightrag/kg/postgres_impl.py CHANGED
@@ -603,7 +603,7 @@ class PGKVStorage(BaseKVStorage):
 
         try:
             results = await self.db.query(sql, params, multirows=True)
-
+
             # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
             if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                 processed_results = {}
@@ -611,19 +611,21 @@ class PGKVStorage(BaseKVStorage):
                     # Parse flattened key to extract cache_type
                     key_parts = row["id"].split(":")
                     cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
-
+
                     # Map field names and add cache_type for compatibility
                     processed_row = {
                         **row,
-                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "return": row.get(
+                            "return_value", ""
+                        ),  # Map return_value to return
                         "cache_type": cache_type,  # Add cache_type from key
                         "original_prompt": row.get("original_prompt", ""),
                         "chunk_id": row.get("chunk_id"),
-                        "mode": row.get("mode", "default")
+                        "mode": row.get("mode", "default"),
                     }
                     processed_results[row["id"]] = processed_row
                 return processed_results
-
+
             # For other namespaces, return as-is
             return {row["id"]: row for row in results}
         except Exception as e:
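
The hunks above only reflow the LLM-cache post-processing for line length; the behaviour is unchanged: each flattened cache key is split to recover the cache_type, and the stored return_value column is exposed under the "return" field that _get_cached_extraction_results expects. Below is a minimal, standalone sketch of that mapping — the helper name and the sample key layout (mode:cache_type:hash) are assumptions for illustration, not the repository's code.

```python
# Illustrative sketch, not LightRAG's implementation.
def postprocess_cache_row(row: dict) -> dict:
    # Assumed flattened key layout: "<mode>:<cache_type>:<hash>",
    # so the middle segment carries the cache_type.
    key_parts = row["id"].split(":")
    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
    return {
        **row,
        "return": row.get("return_value", ""),  # map return_value -> return
        "cache_type": cache_type,               # recovered from the key
        "original_prompt": row.get("original_prompt", ""),
        "chunk_id": row.get("chunk_id"),
        "mode": row.get("mode", "default"),
    }


if __name__ == "__main__":
    # Hypothetical sample row, only to show the mapping.
    sample = {"id": "default:extract:abc123", "return_value": "{...}"}
    print(postprocess_cache_row(sample)["cache_type"])  # -> "extract"
```
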
lightrag/kg/redis_impl.py CHANGED
@@ -14,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
 from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger
 
-from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
+from lightrag.base import (
+    BaseKVStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+)
 import json
 
 
@@ -29,10 +34,10 @@ SOCKET_CONNECT_TIMEOUT = 3.0
 
 class RedisConnectionManager:
     """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
-
+
     _pools = {}
     _lock = threading.Lock()
-
+
     @classmethod
     def get_pool(cls, redis_url: str) -> ConnectionPool:
         """Get or create a connection pool for the given Redis URL"""
@@ -48,7 +53,7 @@ class RedisConnectionManager:
            )
            logger.info(f"Created shared Redis connection pool for {redis_url}")
        return cls._pools[redis_url]
-
+
     @classmethod
     def close_all_pools(cls):
         """Close all connection pools (for cleanup)"""
@@ -254,17 +259,21 @@ class RedisKVStorage(BaseKVStorage):
                     pattern = f"{self.namespace}:{mode}:*"
                     cursor = 0
                     mode_keys = []
-
+
                     while True:
-                        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                        cursor, keys = await redis.scan(
+                            cursor, match=pattern, count=1000
+                        )
                         if keys:
                             mode_keys.extend(keys)
-
+
                         if cursor == 0:
                             break
-
+
                     keys_to_delete.extend(mode_keys)
-                    logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
+                    logger.info(
+                        f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
+                    )
 
                 if keys_to_delete:
                     # Batch delete
@@ -296,7 +305,7 @@ class RedisKVStorage(BaseKVStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0
-
+
                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -306,7 +315,7 @@ class RedisKVStorage(BaseKVStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)
-
+
                     if cursor == 0:
                         break
 
@@ -419,7 +428,9 @@ class RedisDocStatusStorage(DocStatusStorage):
         try:
             async with self._get_redis_connection() as redis:
                 await redis.ping()
-                logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
+                logger.info(
+                    f"Connected to Redis for doc status namespace {self.namespace}"
+                )
         except Exception as e:
             logger.error(f"Failed to connect to Redis for doc status: {e}")
             raise
@@ -475,7 +486,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 for id in ids:
                     pipe.get(f"{self.namespace}:{id}")
                 results = await pipe.execute()
-
+
                 for result_data in results:
                     if result_data:
                         try:
@@ -495,14 +506,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()
-
+
                         # Count statuses
                         for value in values:
                             if value:
@@ -513,12 +526,12 @@ class RedisDocStatusStorage(DocStatusStorage):
                                         counts[status] += 1
                                 except json.JSONDecodeError:
                                     continue
-
+
                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting status counts: {e}")
-
+
         return counts
 
     async def get_docs_by_status(
@@ -531,14 +544,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                 # Use SCAN to iterate through all keys in the namespace
                 cursor = 0
                 while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                     if keys:
                         # Get all values in batch
                         pipe = redis.pipeline()
                         for key in keys:
                             pipe.get(key)
                         values = await pipe.execute()
-
+
                         # Filter by status and create DocProcessingStatus objects
                         for key, value in zip(keys, values):
                             if value:
@@ -547,26 +562,31 @@ class RedisDocStatusStorage(DocStatusStorage):
                                     if doc_data.get("status") == status.value:
                                         # Extract document ID from key
                                         doc_id = key.split(":", 1)[1]
-
+
                                         # Make a copy of the data to avoid modifying the original
                                         data = doc_data.copy()
                                         # If content is missing, use content_summary as content
-                                        if "content" not in data and "content_summary" in data:
+                                        if (
+                                            "content" not in data
+                                            and "content_summary" in data
+                                        ):
                                             data["content"] = data["content_summary"]
                                         # If file_path is not in data, use document id as file path
                                         if "file_path" not in data:
                                             data["file_path"] = "no-file-path"
-
+
                                         result[doc_id] = DocProcessingStatus(**data)
                                 except (json.JSONDecodeError, KeyError) as e:
-                                    logger.error(f"Error processing document {key}: {e}")
+                                    logger.error(
+                                        f"Error processing document {key}: {e}"
+                                    )
                                     continue
-
+
                     if cursor == 0:
                         break
         except Exception as e:
             logger.error(f"Error getting docs by status: {e}")
-
+
         return result
 
     async def index_done_callback(self) -> None:
@@ -577,7 +597,7 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Insert or update document status data"""
         if not data:
             return
-
+
         logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
@@ -602,15 +622,17 @@ class RedisDocStatusStorage(DocStatusStorage):
         """Delete specific records from storage by their IDs"""
         if not doc_ids:
            return
-
+
         async with self._get_redis_connection() as redis:
             pipe = redis.pipeline()
             for doc_id in doc_ids:
                 pipe.delete(f"{self.namespace}:{doc_id}")
-
+
             results = await pipe.execute()
             deleted_count = sum(results)
-            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+            logger.info(
+                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
+            )
 
     async def drop(self) -> dict[str, str]:
         """Drop all document status data from storage and clean up resources"""
@@ -620,7 +642,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                 pattern = f"{self.namespace}:*"
                 cursor = 0
                 deleted_count = 0
-
+
                 while True:
                     cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                     if keys:
@@ -630,11 +652,13 @@ class RedisDocStatusStorage(DocStatusStorage):
                             pipe.delete(key)
                         results = await pipe.execute()
                         deleted_count += sum(results)
-
+
                     if cursor == 0:
                         break
 
-                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                logger.info(
+                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
+                )
                 return {"status": "success", "message": "data dropped"}
         except Exception as e:
             logger.error(f"Error dropping doc status {self.namespace}: {e}")