Fix linting
Changed files:
- lightrag/kg/postgres_impl.py (+7, -5)
- lightrag/kg/redis_impl.py (+56, -32)
lightrag/kg/postgres_impl.py (CHANGED)

@@ -603,7 +603,7 @@ class PGKVStorage(BaseKVStorage):

        try:
            results = await self.db.query(sql, params, multirows=True)
-
+
            # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
            if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
                processed_results = {}

@@ -611,19 +611,21 @@ class PGKVStorage(BaseKVStorage):
                    # Parse flattened key to extract cache_type
                    key_parts = row["id"].split(":")
                    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"
-
+
                    # Map field names and add cache_type for compatibility
                    processed_row = {
                        **row,
-                        "return": row.get("return_value", ""),  # Map return_value to return
+                        "return": row.get(
+                            "return_value", ""
+                        ),  # Map return_value to return
                        "cache_type": cache_type,  # Add cache_type from key
                        "original_prompt": row.get("original_prompt", ""),
                        "chunk_id": row.get("chunk_id"),
-                        "mode": row.get("mode", "default")
+                        "mode": row.get("mode", "default"),
                    }
                    processed_results[row["id"]] = processed_row
                return processed_results
-
+
            # For other namespaces, return as-is
            return {row["id"]: row for row in results}
        except Exception as e:
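For context, the second hunk above normalizes each LLM-cache row before it is returned: the cache_type is recovered from the second segment of the flattened, colon-delimited row id, and the stored return_value column is exposed under the return field that downstream cache readers expect. Below is a minimal standalone sketch of that mapping; the helper name and the sample row are illustrative assumptions, not part of the repository.

# Illustrative sketch only: mirrors the row post-processing in the hunk above.
def normalize_llm_cache_row(row: dict) -> dict:
    """Map storage column names onto the field names the cache reader expects."""
    # The flattened id is colon-delimited; with three or more segments the
    # second one carries the cache type (e.g. "default:extract:abc123" is a made-up example).
    key_parts = row["id"].split(":")
    cache_type = key_parts[1] if len(key_parts) >= 3 else "unknown"

    return {
        **row,
        "return": row.get("return_value", ""),  # column "return_value" -> field "return"
        "cache_type": cache_type,  # recovered from the flattened key
        "original_prompt": row.get("original_prompt", ""),
        "chunk_id": row.get("chunk_id"),
        "mode": row.get("mode", "default"),
    }


# Hypothetical usage with a made-up row:
sample = {"id": "default:extract:abc123", "return_value": "{}", "mode": "default"}
assert normalize_llm_cache_row(sample)["cache_type"] == "extract"
assert normalize_llm_cache_row(sample)["return"] == "{}"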
lightrag/kg/redis_impl.py (CHANGED)

@@ -14,7 +14,12 @@ from redis.asyncio import Redis, ConnectionPool  # type: ignore
from redis.exceptions import RedisError, ConnectionError  # type: ignore
from lightrag.utils import logger

-from lightrag.base import BaseKVStorage, DocStatusStorage, DocStatus, DocProcessingStatus
+from lightrag.base import (
+    BaseKVStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+)
import json


@@ -29,10 +34,10 @@ SOCKET_CONNECT_TIMEOUT = 3.0

class RedisConnectionManager:
    """Shared Redis connection pool manager to avoid creating multiple pools for the same Redis URI"""
-
+
    _pools = {}
    _lock = threading.Lock()
-
+
    @classmethod
    def get_pool(cls, redis_url: str) -> ConnectionPool:
        """Get or create a connection pool for the given Redis URL"""

@@ -48,7 +53,7 @@ class RedisConnectionManager:
            )
            logger.info(f"Created shared Redis connection pool for {redis_url}")
            return cls._pools[redis_url]
-
+
    @classmethod
    def close_all_pools(cls):
        """Close all connection pools (for cleanup)"""

@@ -254,17 +259,21 @@ class RedisKVStorage(BaseKVStorage):
                pattern = f"{self.namespace}:{mode}:*"
                cursor = 0
                mode_keys = []
-
+
                while True:
-                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=pattern, count=1000
+                    )
                    if keys:
                        mode_keys.extend(keys)
-
+
                    if cursor == 0:
                        break
-
+
                keys_to_delete.extend(mode_keys)
-                logger.info(f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'")
+                logger.info(
+                    f"Found {len(mode_keys)} keys for mode '{mode}' with pattern '{pattern}'"
+                )

                if keys_to_delete:
                    # Batch delete

@@ -296,7 +305,7 @@ class RedisKVStorage(BaseKVStorage):
                pattern = f"{self.namespace}:*"
                cursor = 0
                deleted_count = 0
-
+
                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                    if keys:

@@ -306,7 +315,7 @@ class RedisKVStorage(BaseKVStorage):
                            pipe.delete(key)
                        results = await pipe.execute()
                        deleted_count += sum(results)
-
+
                    if cursor == 0:
                        break

@@ -419,7 +428,9 @@ class RedisDocStatusStorage(DocStatusStorage):
        try:
            async with self._get_redis_connection() as redis:
                await redis.ping()
-                logger.info(f"Connected to Redis for doc status namespace {self.namespace}")
+                logger.info(
+                    f"Connected to Redis for doc status namespace {self.namespace}"
+                )
        except Exception as e:
            logger.error(f"Failed to connect to Redis for doc status: {e}")
            raise

@@ -475,7 +486,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()
-
+
                for result_data in results:
                    if result_data:
                        try:

@@ -495,14 +506,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                # Use SCAN to iterate through all keys in the namespace
                cursor = 0
                while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                    if keys:
                        # Get all values in batch
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()
-
+
                        # Count statuses
                        for value in values:
                            if value:

@@ -513,12 +526,12 @@ class RedisDocStatusStorage(DocStatusStorage):
                                        counts[status] += 1
                                except json.JSONDecodeError:
                                    continue
-
+
                    if cursor == 0:
                        break
        except Exception as e:
            logger.error(f"Error getting status counts: {e}")
-
+
        return counts

    async def get_docs_by_status(

@@ -531,14 +544,16 @@ class RedisDocStatusStorage(DocStatusStorage):
                # Use SCAN to iterate through all keys in the namespace
                cursor = 0
                while True:
-                    cursor, keys = await redis.scan(cursor, match=f"{self.namespace}:*", count=1000)
+                    cursor, keys = await redis.scan(
+                        cursor, match=f"{self.namespace}:*", count=1000
+                    )
                    if keys:
                        # Get all values in batch
                        pipe = redis.pipeline()
                        for key in keys:
                            pipe.get(key)
                        values = await pipe.execute()
-
+
                        # Filter by status and create DocProcessingStatus objects
                        for key, value in zip(keys, values):
                            if value:

@@ -547,26 +562,31 @@ class RedisDocStatusStorage(DocStatusStorage):
                                    if doc_data.get("status") == status.value:
                                        # Extract document ID from key
                                        doc_id = key.split(":", 1)[1]
-
+
                                        # Make a copy of the data to avoid modifying the original
                                        data = doc_data.copy()
                                        # If content is missing, use content_summary as content
-                                        if "content" not in data and "content_summary" in data:
+                                        if (
+                                            "content" not in data
+                                            and "content_summary" in data
+                                        ):
                                            data["content"] = data["content_summary"]
                                        # If file_path is not in data, use document id as file path
                                        if "file_path" not in data:
                                            data["file_path"] = "no-file-path"
-
+
                                        result[doc_id] = DocProcessingStatus(**data)
                                except (json.JSONDecodeError, KeyError) as e:
-                                    logger.error(f"Error processing document {key}: {e}")
+                                    logger.error(
+                                        f"Error processing document {key}: {e}"
+                                    )
                                    continue
-
+
                    if cursor == 0:
                        break
        except Exception as e:
            logger.error(f"Error getting docs by status: {e}")
-
+
        return result

    async def index_done_callback(self) -> None:

@@ -577,7 +597,7 @@ class RedisDocStatusStorage(DocStatusStorage):
        """Insert or update document status data"""
        if not data:
            return
-
+
        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:

@@ -602,15 +622,17 @@ class RedisDocStatusStorage(DocStatusStorage):
        """Delete specific records from storage by their IDs"""
        if not doc_ids:
            return
-
+
        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for doc_id in doc_ids:
                pipe.delete(f"{self.namespace}:{doc_id}")
-
+
            results = await pipe.execute()
            deleted_count = sum(results)
-            logger.info(f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}")
+            logger.info(
+                f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}"
+            )

    async def drop(self) -> dict[str, str]:
        """Drop all document status data from storage and clean up resources"""

@@ -620,7 +642,7 @@ class RedisDocStatusStorage(DocStatusStorage):
                pattern = f"{self.namespace}:*"
                cursor = 0
                deleted_count = 0
-
+
                while True:
                    cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
                    if keys:

@@ -630,11 +652,13 @@ class RedisDocStatusStorage(DocStatusStorage):
                            pipe.delete(key)
                        results = await pipe.execute()
                        deleted_count += sum(results)
-
+
                    if cursor == 0:
                        break

-                logger.info(f"Dropped {deleted_count} doc status keys from {self.namespace}")
+                logger.info(
+                    f"Dropped {deleted_count} doc status keys from {self.namespace}"
+                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping doc status {self.namespace}: {e}")
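Most of the redis_impl.py hunks reformat the same access pattern rather than change behavior: a cursor-driven SCAN over namespaced keys, with a pipeline batching the per-key GET or DELETE calls. The following is a self-contained sketch of that pattern, assuming redis-py's asyncio client, a reachable Redis at redis://localhost:6379, and a placeholder "demo" namespace (all assumptions, not values from the repository).

# Illustrative sketch of the SCAN-cursor + pipeline pattern used throughout the diff.
import asyncio

from redis.asyncio import Redis


async def delete_namespace(redis: Redis, namespace: str) -> int:
    """Delete every key matching "<namespace>:*" in batches; return how many were removed."""
    pattern = f"{namespace}:*"
    cursor = 0
    deleted_count = 0

    while True:
        # SCAN returns (next_cursor, keys); count=1000 is only a batching hint.
        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
        if keys:
            pipe = redis.pipeline()
            for key in keys:
                pipe.delete(key)
            results = await pipe.execute()
            deleted_count += sum(results)

        if cursor == 0:  # a zero cursor means the scan has covered the whole keyspace
            break

    return deleted_count


async def main() -> None:
    # Connection URL and namespace are placeholders for this sketch.
    async with Redis.from_url("redis://localhost:6379") as redis:
        print(await delete_namespace(redis, "demo"))


if __name__ == "__main__":
    asyncio.run(main())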